在这个AI快速发展的时代,你是否也想让AI更好地理解和运用你的专业知识?是否希望搭建一个既能准确回答问题,又能保护数据隐私的AI助手?今天,让我们一起探索如何打造一个基于RAG技术的本地知识库聊天助手。
什么是RAG?为什么我们需要它?
想象一下,你正在和一位博学的助手对话。这位助手不仅拥有强大的对话能力,更重要的是,他能实时查阅你提供的专业文档和资料,给出准确的、有依据的回答。这就是RAG(检索增强生成)技术带来的魔力。
简单来说,RAG就像给AI配备了一个"智能笔记本":当你问问题时,它会先在你提供的文档中搜索相关信息,然后基于这些找到的信息,结合自身的语言能力,生成准确而自然的回答。
随着大语言模型(LLM)的发展,我们确实获得了一个能力强大的对话助手。但在实际应用中,我们发现它们往往会产生"幻觉"——看似流畅自然但实际并不准确的回答。这并不奇怪,因为这些模型的知识来自训练数据,不仅可能已经过时,而且无法包含你的专业领域知识。
RAG技术提供了一个优雅的解决方案:它让AI在回答问题时能够实时查阅你提供的文档资料。这就像一个专业助手,在回答问题时会先翻阅相关的参考资料,确保回答既准确又有依据。这不仅大大降低了AI的"幻觉"问题,还让AI能够基于最新、最相关的信息来回答问题。
为什么需要本地知识库?
在实际工作中,我们经常需要处理一些内部文档、技术资料或者私密信息。直接使用在线AI服务意味着要将这些敏感数据上传到第三方服务器,这显然存在安全隐患。而搭建本地知识库就像给AI配备了一个私人图书馆,它只会查阅和使用你授权的文档,既确保了数据安全,又保证了回答的准确性。
通过本地知识库,你的AI助手可以:
1.准确回答关于公司产品、政策或技术文档的问题
2.确保敏感信息不会泄露到外部
3.始终基于最新的内部资料提供答案
4.降低使用成本,避免频繁的API调用费用
我们将会学到什么?
在这个教程中,我会带你一步步实现这个AI知识库助手。我们会经历两个主要阶段:
第一阶段:从最基础的控制台应用开始,帮助你理解RAG的核心工作原理:
如何处理和存储文档?
如何实现向量检索?
如何生成基于本地知识库的“准确”回答?
第二阶段:升级到一个完整的Web应用,使用Python + Flask + HTML构建界面友好的AI聊天助手
我们会用最简单的方式,一步步完成这个项目。每一步我都会尽量详细解释代码的作用。
先来理解一下本地知识库问答助手的工作原理
知识库是如何工作的?
想象一下我们在阅读一本书时的场景:当我们希望回忆这本书的某个情节时,我们会先找到相关的章节,然后再找到相关的段落,最后再基于这些内容来进行总结。本地知识库问答助手也是采用类似的方式工作。
比如当用户询问"项目A的具体预算是多少?"时,系统会经历以下步骤:
第一步:将问题转化为向量形式。就像是在给这个问题制作一个特殊的"标签",这个"标签"能够帮助我们找到最相关的内容。
第二步:在知识库中检索相似内容。系统会将问题的"标签"与之前存储的所有文档片段进行比较,找出最相关的部分。这就像是在书中使用目录或索引快速找到相关章节。
第三步:结合检索到的内容生成答案。系统不是简单地复制找到的内容,而是会理解这些信息,然后生成一个完整、准确的答案。这就像是一个专业的助手,先阅读相关资料,再用自己的话回答你的问题。
为什么需要向量化存储?
传统的关键词搜索往往难以理解问题的真正含义。例如,当用户问"公司的年度发展计划是什么?"时,相关文档中可能并没有出现"年度发展计划"这个确切的词组,但可能会包含"战略规划"、"年度目标"等相关表述。
通过将文本转换为向量,系统能够理解词语之间的语义关联,从而找到真正相关的信息。这就像是在比较两段文本的"意思"而不是简单的"字面内容"。
实际的检索流程示例
系统会进行如下处理:
1. 文档预处理:首先将产品文档分割成合适的片段,每个片段大约500-1000字符。这样可以确保检索的精确性,同时保持足够的上下文信息。
2. 向量存储:对每个文档片段生成向量表示,并存储在本地数据库中。同时保存原始文本和其对应的向量索引。
3. 相似度检索:当收到用户问题时,系统会:
计算问题的向量表示
在向量数据库中找出最相似的几个片段
将这些相关内容作为上下文提供给大语言模型
4. 答案生成:系统会告诉大语言模型:“提示词+检索到的相关内容+用户问题”,确保生成的答案严格基于检索到的文档内容,当然必要的时候也需要附带历史记录
通过这种方式,系统能够在保证回答准确性的同时,有效降低API调用成本,并确保信息安全性。
理解RAG聊天助手的核心逻辑
上面那堆东西看起来还是太复杂了,让我们通过一个具体的业务场景,来深入理解一个最简单的基于RAG的本地知识库问答系统是如何工作的。
一个典型的业务场景
想象你是一家企业的客服主管,每天都要处理大量关于产品、政策和流程的咨询。你有一份详细的产品手册,但在数百页的文档中快速找到准确的信息并不容易。这正是RAG系统能够帮助的地方。
文档处理的实际流程
假设你有一份产品手册文档:
【产品退换货政策】
第一章:退货规则
1.1 所有产品支持7天无理由退货
1.2 退货时需要提供完整的包装和随附配件
1.3 部分定制化产品除外(详见附录A)
第二章:退货流程
2.1 在线提交退货申请
2.2 等待客服审核确认
2.3 按照指定物流进行退回
...
RAG系统会将这份文档处理成多个独立但有关联的知识片段:
片段1:所有产品支持7天无理由退货,退货时需要提供完整的包装和随附配件。部分定制化产品除外。
片段2:退货流程包括:1.在线提交退货申请 2.等待客服审核确认 3.按照指定物流进行退回。
...
问答过程的完整示例
当用户提出问题时,系统的工作流程是这样的:
1. 用户提问:
"请问你们的退货政策是怎么规定的?需要满足什么条件?"
2. RAG系统检索相关文档片段:
相关度:0.95 - "所有产品支持7天无理由退货,退货时需要提供完整的包装和随附配件。部分定制化产品除外。"
相关度:0.82 - "退货流程包括:1.在线提交退货申请 2.等待客服审核确认 3.按照指定物流进行退回。"
3. 系统构建完整的上下文提示:
你是一个专业的客服助手。请基于以下参考资料,准确回答用户的问题。请只使用参考资料中提供的信息,如果信息不足,请明确告知。
参考资料:
1. 所有产品支持7天无理由退货,退货时需要提供完整的包装和随附配件。部分定制化产品除外。
2. 退货流程包括:1.在线提交退货申请 2.等待客服审核确认 3.按照指定物流进行退回。
用户问题:请问你们的退货政策是怎么规定的?需要满足什么条件?
4. 系统基于特殊提问生成特定回答:
根据我们的政策,我们提供7天无理由退货服务,但需要满足以下条件:
1. 在购买后7天内提出退货申请
2. 商品包装和配件需要保持完整
3. 注意:定制化产品可能不适用此政策
退货具体流程是:
1. 在我们的系统中提交退货申请
2. 等待客服团队审核确认
3. 使用指定的物流渠道退回商品
我们之前实现的多轮对话中的上下文理解也可以复用
RAG系统处理后续的追问:
用户追问:
"定制化产品具体指哪些?"
系统新检索的文档片段:
相关度:0.88 - "附录A:定制化产品包括:1.定制尺寸的家具 2.刻字服务的饰品 3.私人订制的电子产品。这些产品一经生产,除质量问题外不支持退换。"
系统回答:
根据文档说明,定制化产品包括:
1. 定制尺寸的家具
2. 经过刻字服务的饰品
3. 私人订制的电子产品
这些定制产品由于其特殊性,除非存在质量问题,否则不支持退换。
通过这个完整的示例,我们可以看到RAG系统的几个关键优势:
1. 精确的信息定位:能够从大量文档中准确找到与问题最相关的内容
2. 可靠的答案生成:回答严格基于文档内容,避免产生虚假信息
3. 专业的对话能力:既保持了AI流畅的表达,又确保了信息的准确性
4. 灵活的上下文理解:能够处理用户的追问,保持对话的连贯性
技术选型与实现细节
向量模型选择:M3E-Large
在构建本地知识库问答系统时,向量模型的选择至关重要。我们选择了 MokaAI 开源的 M3E-Large 模型作为文本向量化的核心组件。这个选择基于以下几个关键考虑:
优势与特点
1. 中英双语支持:M3E-Large 经过2200万+中文句对数据的训练,能够同时处理中英文本,满足多语言场景需求。
2. 轻量级部署:作为一个小型模型,M3E-Large 可以在CPU环境下运行,降低了硬件要求,使私有化部署更加灵活。
3. 开源可控:模型完全开源,可以根据具体需求进行微调和优化。
4. 性能表现:在文本相似度计算和检索任务中表现出色,与商业模型(如OpenAI的text-embedding系列模型)相比具有竞争力。
向量数据库的存储设计
我们这里采用了混合存储策略,结合SQLite关系数据库和文件系统,实现高效的向量检索和文本存储。
数据库设计
系统使用三个主要表格组织数据:
1. chunks表:存储文档分块信息
CREATE TABLE chunks (
id TEXT PRIMARY KEY,
filename TEXT,
chunk_text TEXT,
chunk_index INTEGER,
created_at TIMESTAMP,
vector_ids TEXT
)
2. vector_mappings表:管理向量ID与文本块的映射关系
CREATE TABLE vector_mappings (
vector_id INTEGER PRIMARY KEY AUTOINCREMENT,
chunk_id TEXT,
embedding_type TEXT,
FOREIGN KEY (chunk_id) REFERENCES chunks(id)
)
3. processed_files表:追踪文件处理状态
CREATE TABLE processed_files (
filename TEXT PRIMARY KEY,
file_hash TEXT,
processed_at TIMESTAMP,
file_size INTEGER,
chunk_count INTEGER
)
向量存储优化
为了提高检索效率,我们采用了以下策略:
1. 向量文件存储:使用NumPy的二进制格式(.npy)存储向量数据,实现快速加载和检索。
2. 近邻索引:使用sklearn的NearestNeighbors实现向量检索,采用余弦相似度度量。
3. 增量更新:支持文档的增量处理,避免重复计算已处理文档的向量。
文档处理流程
让我们通过一个具体的pdf/word文档处理示例,来理解整个系统的工作流程:
文档预处理
1. 文件验证:
def calculate_file_hash(self, file_path: str) -> str:
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
2. 文本提取:
def read_pdf(self, pdf_path: str) -> str:
try:
doc = fitz.open(pdf_path)
text = ""
for page in doc:
text += page.get_text()
return text
except Exception as e:
raise ValueError(f"PDF读取失败: {str(e)}")
def read_word(self, doc_path: str) -> str:
try:
doc = Document(doc_path)
text = []
for paragraph in doc.paragraphs:
text.append(paragraph.text)
return '\n'.join(text)
except Exception as e:
raise ValueError(f"Word文档读取失败: {str(e)}")
分块策略
我们采用了智能的文本分块策略,确保每个块都具有完整的语义:
def split_text(self, text: str) -> List[str]:
try:
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = min(start + self.chunk_size, text_length)
if end < text_length:
split_point = text.rfind('.', start, end)
if split_point == -1:
for punct in ['. ', '? ', '! ', '\n', '. ', ', ', ' ']:
split_point = text.rfind(punct, start, end)
if split_point != -1:
break
if split_point == -1:
split_point = end
chunks.append(text[start:split_point + 1].strip())
start = split_point + 1
else:
chunks.append(text[start:end].strip())
break
return [chunk for chunk in chunks if chunk]
except MemoryError:
raise ValueError("文件太大,无法处理。请尝试拆分文件后重试。")
except Exception as e:
raise ValueError(f"分割文本时出错: {str(e)}")
向量化与索引
文本块生成后,通过M3E-Large模型进行向量化:
1. 向量生成:调用本地部署的向量服务
2. 索引更新:将新向量添加到现有索引
3. 数据持久化:同步更新数据库和向量文件
同时m3e模型的官方MokaAI开源了基本的使用逻辑:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('moka-ai/m3e-base')
#Our sentences we like to encode
sentences = [
'* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem',
'* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
'* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one'
]
#Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)
#Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
print("Sentence:", sentence)
print("Embedding:", embedding)
print("")
M3E 系列的所有模型在设计的时候就考虑到完全兼容 sentence-transformers ,所以我们可以通过替换名称字符串的方式在所有支持 sentence-transformers 的项目中无缝使用 M3E Models,比如 chroma, guidance, semantic-kernel 。 (当然我们暂时不需要探究这部分)
现在让我们开始实战部分的代码设计与实现
先封装一个简单的 文本嵌入模型 API 服务
我们之前提到过:OpenAI的API格式因其设计合理性和广泛采用,已成为行业事实标准,所以我们对于文本向量化的标准逻辑也是参考Openai格式的文本嵌入部分的代码:
post https://api.openai.com/v1/embeddings
Openai给出的相关的请求逻辑是:
curl https://api.openai.com/v1/embeddings \
-H "Authorization: Bearer $OPENAI_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"input": "The food was delicious and the waiter...",
"model": "text-embedding-ada-002",
"encoding_format": "float"
}'
或
from openai import OpenAI
client = OpenAI(api_key="<API Key>", base_url="https://api.openai.com/v1")
client.embeddings.create(
model="text-embedding-ada-002",
input="The food was delicious and the waiter...",
encoding_format="float"
)
响应的格式是:
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [
0.0023064255,
-0.009327292,
.... (1536 floats total for ada-002)
-0.0028842222,
],
"index": 0
}
],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 8,
"total_tokens": 8
}
}
但是如果直接使用OpenAI的嵌入式服务会受到网络延迟与高成本的问题, 所以我们现在才要把将 m3e 模型封装成一个类似于 OpenAI 文本嵌入接口的格式,允许用户通过 API 调用来获取文本的嵌入向量。我们通过这个接口,可以把文本输入传递给模型,模型会返回相应的嵌入向量,格式和 OpenAI 的 text-embedding-ada-002 模型一致。
这部分的代码在底下粘贴一下,其实是无意间找到的,我自己简单调整了一下,还能用暂时就不过多改动了:
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import tiktoken
import numpy as np
from scipy.interpolate import interp1d
from typing import List
from sklearn.preprocessing import PolynomialFeatures
import torch
import os
#环境变量传入
sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
# 创建一个FastAPI实例
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 创建一个HTTPBearer实例
security = HTTPBearer()
# 预加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 检测是否有GPU可用,如果有则使用cuda设备,否则使用cpu设备
if torch.cuda.is_available():
print('本次加载模型的设备为GPU: ', torch.cuda.get_device_name(0))
else:
print('本次加载模型的设备为CPU.')
# model_path = "./m3e-large" # 本地目录路径
model = SentenceTransformer('./m3e-large', device=device)
class EmbeddingRequest(BaseModel):
input: List[str]
model: str
class EmbeddingResponse(BaseModel):
data: list
model: str
object: str
usage: dict
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding('cl100k_base')
num_tokens = len(encoding.encode(string))
return num_tokens
# 插值法
def interpolate_vector(vector, target_length):
original_indices = np.arange(len(vector))
target_indices = np.linspace(0, len(vector)-1, target_length)
f = interp1d(original_indices, vector, kind='linear')
return f(target_indices)
def expand_features(embedding, target_length):
poly = PolynomialFeatures(degree=2)
expanded_embedding = poly.fit_transform(embedding.reshape(1, -1))
expanded_embedding = expanded_embedding.flatten()
if len(expanded_embedding) > target_length:
# 如果扩展后的特征超过目标长度,可以通过截断或其他方法来减少维度
expanded_embedding = expanded_embedding[:target_length]
elif len(expanded_embedding) < target_length:
# 如果扩展后的特征少于目标长度,可以通过填充或其他方法来增加维度
expanded_embedding = np.pad(expanded_embedding, (0, target_length - len(expanded_embedding)))
return expanded_embedding
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
if credentials.credentials != sk_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization code",
)
# 计算嵌入向量和tokens数量
embeddings = [model.encode(text) for text in request.input]
# 如果嵌入向量的维度不为1536,则使用插值法扩展至1536维度
# embeddings = [interpolate_vector(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]
# 如果嵌入向量的维度不为1536,则使用特征扩展法扩展至1536维度
embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings]
# Min-Max normalization
# embeddings = [(embedding - np.min(embedding)) / (np.max(embedding) - np.min(embedding)) if np.max(embedding) != np.min(embedding) else embedding for embedding in embeddings]
embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
# 将numpy数组转换为列表
embeddings = [embedding.tolist() for embedding in embeddings]
prompt_tokens = sum(len(text.split()) for text in request.input)
total_tokens = sum(num_tokens_from_string(text) for text in request.input)
response = {
"data": [
{
"embedding": embedding,
"index": index,
"object": "embedding"
} for index, embedding in enumerate(embeddings)
],
"model": request.model,
"object": "list",
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": total_tokens,
}
}
return response
if __name__ == "__main__":
uvicorn.run("app:app", host='0.0.0.0', port=6006, workers=1)
具体结构如下:
文本嵌入生成:接受一个文本输入(字符串或字符串列表),使用 m3e 模型(通过 sentence-transformers 库)将其转换为一个 1536 维的嵌入向量。
API 接口:提供 /v1/embeddings 接口,格式和 OpenAI 的 API 非常类似。你可以通过 HTTP 请求向该接口发送文本数据,接口将返回与 OpenAI text-embedding-ada-002 类似的嵌入向量。
认证安全:使用 Bearer Token 验证机制,保证只有提供有效密钥的用户才能访问 API。
数据处理:在生成嵌入向量后,代码还对向量进行了归一化(使其长度为 1),确保返回的向量标准化,方便后续的相似度计算等任务。
使用方式:
请求方式:通过 HTTP POST 请求调用 /v1/embeddings 接口,传递 input(文本输入)和 model(模型名称)参数。
响应格式:响应中包含嵌入向量(一个 1536 维的浮点数数组)和模型相关信息,符合 OpenAI API 的返回格式。
启动这个服务的命令是:(当然要简单点也可以改成flask)
uvicorn embedding_service:app --host 0.0.0.0 --port 6006 --workers 1
再创建一个RAG对话聊天服务
from openai import OpenAI
import fitz
from docx import Document
import os
from typing import List, Dict
import sqlite3
from datetime import datetime
from uuid import uuid4
import numpy as np
from sklearn.neighbors import NearestNeighbors
import pickle
from tqdm import tqdm
import hashlib
class RAGChatBot:
def __init__(self):
self.chat_client = OpenAI(
api_key="sk-xxx",
base_url="https://api.openai.com/v1"
)
self.embedding_client = OpenAI(
api_key="sk-aaabbbcccdddeeefffggghhhiiijjjkkk",
base_url="http://localhost:6006/v1"
)
self.messages = []
self.chunk_size = 500
self.chunk_overlap = 100
self.vector_dim = 1536
self.index_file = "vector_store.pkl"
self.vectors_file = "vectors.npy"
self.vector_index = NearestNeighbors(n_neighbors=5, metric='cosine')
if os.path.exists(self.index_file) and os.path.exists(self.vectors_file):
with open(self.index_file, 'rb') as f:
self.vector_index = pickle.load(f)
self.stored_vectors = np.load(self.vectors_file)
self.vector_index.fit(self.stored_vectors)
else:
self.stored_vectors = np.empty((0, self.vector_dim))
self._init_database()
def _init_database(self):
conn = sqlite3.connect("knowledge_base.db")
cursor = conn.cursor()
# 存储文档块和元数据的表
cursor.execute('''
CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
filename TEXT,
chunk_text TEXT,
chunk_index INTEGER,
created_at TIMESTAMP,
vector_ids TEXT
)
''')
# 块ID和向量ID之间的映射表
cursor.execute('''
CREATE TABLE IF NOT EXISTS vector_mappings (
vector_id INTEGER PRIMARY KEY AUTOINCREMENT,
chunk_id TEXT,
embedding_type TEXT,
FOREIGN KEY (chunk_id) REFERENCES chunks(id)
)
''')
# 新增:已处理文件记录表
cursor.execute('''
CREATE TABLE IF NOT EXISTS processed_files (
filename TEXT PRIMARY KEY,
file_hash TEXT,
processed_at TIMESTAMP,
file_size INTEGER,
chunk_count INTEGER
)
''')
conn.commit()
conn.close()
def calculate_file_hash(self, file_path: str) -> str:
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def is_file_processed(self, file_path: str) -> bool:
"""检查文件是否已经处理过"""
filename = os.path.basename(file_path)
file_hash = self.calculate_file_hash(file_path)
file_size = os.path.getsize(file_path)
conn = sqlite3.connect("knowledge_base.db")
cursor = conn.cursor()
cursor.execute('''
SELECT file_hash, file_size FROM processed_files
WHERE filename = ?
''', (filename,))
result = cursor.fetchone()
conn.close()
if result:
stored_hash, stored_size = result
return stored_hash == file_hash and stored_size == file_size
return False
def mark_file_processed(self, file_path: str, chunk_count: int):
"""将文件标记为已处理"""
filename = os.path.basename(file_path)
file_hash = self.calculate_file_hash(file_path)
file_size = os.path.getsize(file_path)
conn = sqlite3.connect("knowledge_base.db")
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO processed_files
(filename, file_hash, processed_at, file_size, chunk_count)
VALUES (?, ?, ?, ?, ?)
''', (filename, file_hash, datetime.now(), file_size, chunk_count))
conn.commit()
conn.close()
def read_pdf(self, pdf_path: str) -> str:
try:
doc = fitz.open(pdf_path)
text = ""
for page in doc:
text += page.get_text()
return text
except Exception as e:
raise ValueError(f"PDF读取失败: {str(e)}")
def read_word(self, doc_path: str) -> str:
try:
doc = Document(doc_path)
text = []
for paragraph in doc.paragraphs:
text.append(paragraph.text)
return '\n'.join(text)
except Exception as e:
raise ValueError(f"Word文档读取失败: {str(e)}")
def process_document(self, file_path: str) -> str:
try:
if file_path.lower().endswith('.pdf'):
return self.read_pdf(file_path)
elif file_path.lower().endswith(('.docx', '.doc')):
return self.read_word(file_path)
else:
raise ValueError("不支持的文件格式。仅支持PDF和Word文档。")
except Exception as e:
raise ValueError(f"处理文件失败: {str(e)}")
def split_text(self, text: str) -> List[str]:
try:
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = min(start + self.chunk_size, text_length)
if end < text_length:
split_point = text.rfind('.', start, end)
if split_point == -1:
for punct in ['. ', '? ', '! ', '\n', '. ', ', ', ' ']:
split_point = text.rfind(punct, start, end)
if split_point != -1:
break
if split_point == -1:
split_point = end
chunks.append(text[start:split_point + 1].strip())
start = split_point + 1
else:
chunks.append(text[start:end].strip())
break
return [chunk for chunk in chunks if chunk]
except MemoryError:
raise ValueError("文件太大,无法处理。请尝试拆分文件后重试。")
except Exception as e:
raise ValueError(f"分割文本时出错: {str(e)}")
def get_embeddings(self, texts: List[str]) -> List[List[float]]:
try:
response = self.embedding_client.embeddings.create(
model="m3e-large",
input=texts
)
return [item.embedding for item in response.data]
except Exception as e:
raise ValueError(f"获取向量嵌入失败: {str(e)}")
def index_document(self, file_path: str):
filename = os.path.basename(file_path)
# 检查文件是否已处理
if self.is_file_processed(file_path):
print(f"\n文件 {filename} 已经处理过,跳过处理...")
return
try:
text = self.process_document(file_path)
print(f"\n处理文件: {filename}")
chunks = self.split_text(text)
print(f"文件已分割为 {len(chunks)} 个文本块")
if not chunks:
raise ValueError("未能提取到有效文本内容")
conn = sqlite3.connect("knowledge_base.db")
cursor = conn.cursor()
all_new_embeddings = []
for chunk_idx, chunk in enumerate(tqdm(chunks, desc="生成向量嵌入", unit="块")):
try:
embeddings = self.get_embeddings([chunk])[0]
all_new_embeddings.append(embeddings)
chunk_id = str(uuid4())
vector_id = len(self.stored_vectors) + len(all_new_embeddings) - 1
cursor.execute('''
INSERT INTO chunks (id, filename, chunk_text, chunk_index, created_at, vector_ids)
VALUES (?, ?, ?, ?, ?, ?)
''', (chunk_id, filename, chunk, chunk_idx, datetime.now(), str([vector_id])))
cursor.execute('''
INSERT INTO vector_mappings (vector_id, chunk_id, embedding_type)
VALUES (?, ?, ?)
''', (vector_id, chunk_id, 'default'))
if chunk_idx % 10 == 0:
conn.commit()
except Exception as e:
print(f"\n处理第 {chunk_idx + 1} 个文本块时出错: {str(e)}")
continue
if all_new_embeddings:
new_embeddings = np.array(all_new_embeddings)
self.stored_vectors = np.vstack([self.stored_vectors, new_embeddings])
self.vector_index.fit(self.stored_vectors)
with open(self.index_file, 'wb') as f:
pickle.dump(self.vector_index, f)
np.save(self.vectors_file, self.stored_vectors)
conn.commit()
conn.close()
# 标记文件为已处理
self.mark_file_processed(file_path, len(chunks))
print(f"\n成功完成文件索引: {filename}")
print(f"- 总计处理文本块: {len(chunks)} 个")
print(f"- 向量维度: {self.vector_dim}")
print(f"- 数据库记录已更新")
print("-" * 50)
except Exception as e:
print(f"\n处理文件失败: {filename}")
print(f"错误信息: {str(e)}")
print("请检查文件格式是否正确,以及文件是否可以正常打开。")
print("-" * 50)
def retrieve_context(self, query: str) -> str:
query_embedding = self.get_embeddings([query])[0]
distances, indices = self.vector_index.kneighbors(
[query_embedding],
n_neighbors=min(5, len(self.stored_vectors))
)
similarities = 1 - distances[0]
conn = sqlite3.connect("knowledge_base.db")
cursor = conn.cursor()
chunk_scores: Dict[str, float] = {}
for vector_id, similarity in zip(indices[0], similarities):
cursor.execute('SELECT chunk_id FROM vector_mappings WHERE vector_id = ?', (int(vector_id),))
result = cursor.fetchone()
if result:
chunk_id = result[0]
if chunk_id not in chunk_scores or similarity > chunk_scores[chunk_id]:
chunk_scores[chunk_id] = similarity
contexts = []
for chunk_id, score in sorted(chunk_scores.items(), key=lambda x: x[1], reverse=True):
cursor.execute('SELECT filename, chunk_text FROM chunks WHERE id = ?', (chunk_id,))
result = cursor.fetchone()
if result:
filename, chunk_text = result
contexts.append(f"来自文件 {filename} (相关度: {score:.2f}):\n{chunk_text}")
conn.close()
return "\n\n".join(contexts)
def chat(self, user_input: str):
context = self.retrieve_context(user_input)
system_prompt = f"""你是一个知识库问答助手。请基于以下参考文档回答用户的问题。
如果无法在参考文档中找到答案,请明确告知。回答时请注明信息来源。
参考文档:
{context}"""
if not self.messages:
self.messages.append({"role": "system", "content": system_prompt})
self.messages.append({"role": "user", "content": user_input})
completion = self.chat_client.chat.completions.create(
model="gpt-4o-mini",
messages=self.messages,
stream=True
)
full_response = ""
for chunk in completion:
content = chunk.choices[0].delta.content
if content:
full_response += content
print(content, end='', flush=True)
self.messages.append({"role": "assistant", "content": full_response})
print("\n")
def main():
chatbot = RAGChatBot()
docs_directory = "docs"
if not os.path.exists(docs_directory):
os.makedirs(docs_directory)
print(f"已创建文档目录: {docs_directory}")
doc_count = 0
for filename in os.listdir(docs_directory):
if filename.lower().endswith(('.pdf', '.docx', '.doc')):
file_path = os.path.join(docs_directory, filename)
print(f"正在检查文件 {filename}...")
chatbot.index_document(file_path)
doc_count += 1
if doc_count == 0:
print(f"\n警告:在 {docs_directory} 目录中没有找到任何PDF或Word文档。")
print("请将文档放入该目录后重新运行程序。")
return
print("\n欢迎使用知识库问答助手!")
print("您可以询问任何关于已上传文档的问题。")
print("输入 'quit' 退出对话。")
while True:
user_input = input("\n您: ")
if user_input.lower() == 'quit':
print("感谢使用知识库问答助手,再见!")
break
chatbot.chat(user_input)
if __name__ == "__main__":
main()
RAGChatBot 实现说明
实现一个可靠的本地知识库问答系统,关键在于处理好文档管理、向量检索和对话生成这三个核心环节。让我们详细解析 RAGChatBot 类的设计思路和核心实现。
系统初始化与配置
RAGChatBot 在初始化时建立了两个关键的客户端连接:
def __init__(self):
self.chat_client = OpenAI(
api_key="sk-xxx",
base_url="https://api.openai.com/v1"
)
self.embedding_client = OpenAI(
api_key="sk-yyy",
base_url="http://localhost:6006/v1"
)
chat_client负责与大语言模型通信,处理对话生成任务。而 embedding_client则连接到我们之前搭建的本地文本向量化服务,用于文档和问题的向量表示。
系统采用了混合存储策略,通过SQLite数据库存储文档块和向量映射信息,同时使用文件系统存储向量数据。这种设计既保证了检索效率,又便于管理和扩展:
def _init_database(self):
# chunks表:存储文档块和元数据
cursor.execute('''
CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY,
filename TEXT,
chunk_text TEXT,
chunk_index INTEGER,
created_at TIMESTAMP,
vector_ids TEXT
)
''')
# vector_mappings表:管理向量ID与文本块的映射关系
cursor.execute('''
CREATE TABLE IF NOT EXISTS vector_mappings (
vector_id INTEGER PRIMARY KEY AUTOINCREMENT,
chunk_id TEXT,
embedding_type TEXT,
FOREIGN KEY (chunk_id) REFERENCES chunks(id)
)
''')
文档处理与向量索引
文档处理采用了智能分块策略,通过 chunk_size 和 chunk_overlap 参数控制分块大小和重叠程度,确保语义完整性:
self.chunk_size = 500 # 默认分块大小
self.chunk_overlap = 100 # 分块重叠长度
self.vector_dim = 1536 # 向量维度
index_document 方法实现了文档索引的完整流程:
1. 文件查重:通过计算文件哈希值避免重复处理
2. 文本提取:支持 PDF 和 Word 格式的文档读取
3. 智能分块:确保每个文本块的语义完整性
4. 向量生成:调用本地向量服务生成文本嵌入
5. 数据存储:同步更新数据库和向量存储
检索与对话生成
retrieve_context 方法使用余弦相似度进行向量检索,找出与用户问题最相关的文档片段:
def retrieve_context(self, query: str) -> str:
query_embedding = self.get_embeddings([query])[0]
distances, indices = self.vector_index.kneighbors(
[query_embedding],
n_neighbors=min(5, len(self.stored_vectors))
)
chat 方法将检索到的上下文与用户问题相结合,构建合适的提示词,指导大语言模型生成准确的回答:
def chat(self, user_input: str):
context = self.retrieve_context(user_input)
system_prompt = f"""你是一个知识库问答助手。
请基于以下参考文档回答用户的问题。
如果无法在参考文档中找到答案,请明确告知。
回答时请注明信息来源。
参考文档:
{context}"""
错误处理与优化
系统实现了完善的错误处理机制,确保在文件读取、向量生成等环节出现异常时能够优雅降级:
try:
embeddings = self.get_embeddings([chunk])[0]
except Exception as e:
print(f"\n处理第 {chunk_idx + 1} 个文本块时出错: {str(e)}")
continue
同时,通过批量提交和定期保存等策略优化了数据处理性能:
if chunk_idx % 10 == 0:
conn.commit()
通过这种模块化和健壮的设计,我们的 RAGChatBot 能够高效处理文档,准确回答问题,同时保证了系统的可靠性和可扩展性。无论是处理简单的产品咨询,还是复杂的技术问答,都能提供准确且有依据的回答。
运行效果: