Text Splitting Granularity

Drawbacks:

  1. Chunks that are too large can make retrieval imprecise; chunks that are too small can leave out necessary information
  2. The answer to a question may span two chunks

Improvement: split the text at a chosen granularity with partial overlap between chunks, so each chunk carries more complete context

pip install nltk
import nltk
from nltk.tokenize import sent_tokenize
import json

nltk.download('punkt')  # sent_tokenize needs the punkt sentence-tokenizer models


# chunk_size is usually set according to the document's content and size
# overlap_size is usually set to 10%-20% of chunk_size
def split_text(paragraphs, chunk_size=2000, overlap_size=300):
    '''Split the text into overlapping chunks of the given chunk_size and overlap_size'''
    sentences = [s.strip() for p in paragraphs for s in sent_tokenize(p)]
    chunks = []
    i = 0
    while i < len(sentences):
        chunk = sentences[i]
        overlap = ''
        prev = i - 1
        # walk backwards to build the overlap with the previous chunk
        while prev >= 0 and len(sentences[prev]) + len(overlap) <= overlap_size:
            overlap = sentences[prev] + ' ' + overlap
            prev -= 1
        chunk = overlap + chunk
        next = i + 1
        # walk forwards to fill the rest of the current chunk
        while next < len(sentences) and len(sentences[next]) + len(chunk) <= chunk_size:
            chunk = chunk + ' ' + sentences[next]
            next += 1
        chunks.append(chunk)
        i = next
    return chunks
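
A quick usage sketch (the sample paragraphs and the small sizes are made up for illustration):

paragraphs = [
    "Llama 2 is a family of pretrained and fine-tuned LLMs. It was released by Meta.",
    "On the benchmarks we tested, Llama 2-Chat models generally perform better than existing open-source models.",
]
for chunk in split_text(paragraphs, chunk_size=120, overlap_size=30):
    print(chunk)
    print('---')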

The sent_tokenize used above is an English-specific implementation; for Chinese, use something like the following:

import re
import jieba
import nltk
from nltk.corpus import stopwords

nltk.download('stopwords')

def to_keywords(input_string):
    """Turn a sentence into a sequence of search keywords"""
    # tokenize in search-engine mode
    word_tokens = jieba.cut_for_search(input_string)
    # load the stop-word list
    stop_words = set(stopwords.words('chinese'))
    # drop the stop words
    filtered_sentence = [w for w in word_tokens if w not in stop_words]
    return ' '.join(filtered_sentence)

def sent_tokenize(input_string):
    """Split into sentences by punctuation"""
    # split on sentence-ending punctuation
    sentences = re.split(r'(?<=[。!?;?!])', input_string)
    # drop empty strings
    return [sentence for sentence in sentences if sentence.strip()]
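
A quick check of the two helpers (the sample strings are made up for illustration):

print(to_keywords("非小细胞肺癌的患者有哪些治疗方案"))
print(sent_tokenize("这是第一句。这是第二句!这是第三句?"))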

The splitting methods above are fairly basic:

  1. Split on simple rules (e.g. \n\n\n, periods, question marks)
  2. Overlap-split by character counts chunk_size and overlap_size

For complex text or business requirements, you can go further:

  • Fine-tune a splitting model with an NSP (Next Sentence Prediction) task, feeding it your own business data

This can give better splits, e.g. judging whether two sentences (or paragraphs) A and B are related, and merging them if they are (see the sketch below).
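
A minimal sketch of that idea, using the off-the-shelf NSP head of bert-base-chinese from Hugging Face transformers. The should_merge helper and the 0.5 threshold are illustrative assumptions; for real use you would first fine-tune the model on your own business data as described above:

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
nsp_model = BertForNextSentencePrediction.from_pretrained("bert-base-chinese")

def should_merge(sent_a, sent_b, threshold=0.5):
    '''Use the NSP head to judge whether sent_b plausibly follows sent_a'''
    inputs = tokenizer(sent_a, sent_b, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = nsp_model(**inputs).logits
    # for BertForNextSentencePrediction, label 0 means "B is the next sentence of A"
    prob_next = torch.softmax(logits, dim=-1)[0, 0].item()
    return prob_next > threshold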

Re-Ranking after Retrieval

Problem: sometimes the best answer is not ranked first in the retrieval results

Approach:

  1. Over-recall at retrieval time (use a larger TopK)
  2. Re-score and re-rank the query/document pairs with a ranking model

The documents retrieved from the vector database are re-scored and re-ranked by a Re-Ranker model

pip install sentence_transformers

Download a Re-Ranker model to your local machine:

from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)  # English, smaller model
# model = CrossEncoder('BAAI/bge-reranker-large', max_length=512)  # multilingual, Chinese-made, larger model

Test it:

user_query = "how safe is llama 2"
# user_query = "llama 2 安全性如何"  # the same query in Chinese, for the multilingual model
scores = model.predict([(user_query, doc) for doc in search_results['documents'][0]])
# sort by score
sorted_list = sorted(zip(scores, search_results['documents'][0]), key=lambda x: x[0], reverse=True)
for score, doc in sorted_list:
    print(f"{score}\t{doc}\n")
 
# 6.613734722137451	We believe that the open release of LLMs, when done safely, will be a net benefit to society. Like all LLMs, Llama 2 is a new technology that carries potential risks with use (Bender et al., 2021b; Weidinger et al., 2021; Solaiman et al., 2023).
# 
# 5.310717582702637    In this work, we develop and release Llama 2, a family of pretrained and fine-tuned LLMs, Llama 2 and Llama 2-Chat, at scales up to 70B parameters. On the series of helpfulness and safety benchmarks we tested, Llama 2-Chat models generally perform better than existing open-source models.
# 
# 4.709955215454102    We provide a responsible use guide¶ and code examples‖ to facilitate the safe deployment of Llama 2 and Llama 2-Chat. More details of our responsible release strategy can be found in Section 5.3.
# 
# 4.5439653396606445    We also share novel observations we made during the development of Llama 2 and Llama 2-Chat, such as the emergence of tool usage and temporal organization of knowledge. Figure 3: Safety human evaluation results for Llama 2-Chat compared to other open-source and closed source models.
# 
# 4.0338897705078125    Additionally, these safety evaluations are performed using content standards that are likely to be biased towards the Llama 2-Chat models. We are releasing the following models to the general public for research and commercial use‡: 1.
# 

If you would rather not download a Re-Ranker model locally, some online API services offer the same capability; Cohere's Rerank API is one example, sketched below.
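
A minimal sketch against Cohere's Python SDK. The client constructor, model name, and response fields follow Cohere's documentation at the time of writing and may differ across SDK versions, so treat them as assumptions to verify:

import cohere

co = cohere.Client("YOUR_COHERE_API_KEY")

rerank_results = co.rerank(
    model="rerank-multilingual-v3.0",  # a multilingual rerank model; pick per Cohere's current docs
    query=user_query,
    documents=search_results['documents'][0],
    top_n=3,
)
for r in rerank_results.results:
    # each result carries the index of the original document and a relevance score
    print(r.relevance_score, search_results['documents'][0][r.index])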

Hybrid Search

In production, traditional keyword search (sparse representations) and vector search (dense representations) each have their strengths and weaknesses.

A concrete example: when documents contain long domain-specific terms, keyword search is usually more precise, while vector search tends to confuse related concepts.
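
The example below uses two helpers defined earlier in the course, get_embeddings and cos_sim. For reference, a minimal sketch of them, assuming OpenAI's embeddings API (the model name is illustrative; any embedding model works):

import numpy as np
from openai import OpenAI

client = OpenAI()

def get_embeddings(texts, model="text-embedding-3-small"):
    '''Embed a batch of texts, returning one vector per text'''
    data = client.embeddings.create(input=texts, model=model).data
    return [item.embedding for item in data]

def cos_sim(a, b):
    '''Cosine similarity between two vectors'''
    a, b = np.array(a), np.array(b)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))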

# Background: in medicine, "小细胞肺癌" (small cell lung cancer) and
# "非小细胞肺癌" (non-small cell lung cancer) are two different cancers

query = "非小细胞肺癌的患者"  # "patients with non-small cell lung cancer"

documents = [
    "玛丽患有肺癌,癌细胞已转移",      # "Mary has lung cancer; the cancer cells have metastasized"
    "刘某肺癌I期",                    # "Mr. Liu has stage I lung cancer"
    "张某经诊断为非小细胞肺癌III期",  # "Mr. Zhang was diagnosed with stage III non-small cell lung cancer"
    "小细胞肺癌是肺癌的一种"          # "Small cell lung cancer is a type of lung cancer"
]

query_vec = get_embeddings([query])[0]
doc_vecs = get_embeddings(documents)

print("Cosine distance:")
for vec in doc_vecs:
    print(cos_sim(query_vec, vec))

# Cosine distance:
# 0.8915956814209096
# 0.8902380296876382
# 0.9043403228477503
# 0.9136486327152477

# For these domain-specific terms, "小细胞肺癌" and "非小细胞肺癌" get mixed
# together: documents 3 and 4 both score high

So sometimes we need to combine different retrieval algorithms to get better results than any single one can; this is hybrid search.

The core of hybrid search: combine a document's rank under each retrieval algorithm into one final ordering.

One of the most common algorithms is Reciprocal Rank Fusion (RRF):

    RRF(d) = \sum_{a \in A} \frac{1}{k + rank_a(d)}

where A is the set of retrieval algorithms used, rank_a(d) is the rank of document d when retrieved with algorithm a, and k is a constant. For example, with k = 1 and 0-based ranks (as in the code below), a document ranked first by keyword search and third by vector search scores 1/(1+0) + 1/(1+2) ≈ 1.33.

Many vector databases support hybrid search out of the box (see the feature comparison of mainstream vector databases), or you can implement it yourself from the formula above.

Let's hand-write a simple example

Note: you need an Elasticsearch server installed and running!

# Elasticsearch Python client
pip install elasticsearch7
  1. Ranking by keyword search

import time
from elasticsearch7 import helpers

class MyEsConnector:
    def __init__(self, es_client, index_name, keyword_fn):
        self.es_client = es_client
        self.index_name = index_name
        self.keyword_fn = keyword_fn

    def add_documents(self, documents):
        '''Load the documents into the index'''
        if self.es_client.indices.exists(index=self.index_name):
            self.es_client.indices.delete(index=self.index_name)
        self.es_client.indices.create(index=self.index_name)
        actions = [
            {
                "_index": self.index_name,
                "_source": {
                    "keywords": self.keyword_fn(doc),
                    "text": doc,
                    "id": f"doc_{i}"
                }
            }
            for i, doc in enumerate(documents)
        ]
        helpers.bulk(self.es_client, actions)
        time.sleep(1)
 
    def search(self, query_string, top_n=3):
        '''Search the index'''
        search_query = {
            "match": {
                "keywords": self.keyword_fn(query_string)
            }
        }
        res = self.es_client.search(
            index=self.index_name, query=search_query, size=top_n)
        return {
            hit["_source"]["id"]: {
                "text": hit["_source"]["text"],
                "rank": i,
            }
            for i, hit in enumerate(res["hits"]["hits"])
        }
import os
import json
from elasticsearch7 import Elasticsearch
from chinese_utils import to_keywords  # the Chinese keyword-extraction helper shown above

# read the configuration from environment variables
ELASTICSEARCH_BASE_URL = os.getenv('ELASTICSEARCH_BASE_URL')
ELASTICSEARCH_PASSWORD = os.getenv('ELASTICSEARCH_PASSWORD')
ELASTICSEARCH_NAME = os.getenv('ELASTICSEARCH_NAME')
 
es = Elasticsearch(
    hosts=[ELASTICSEARCH_BASE_URL],
    http_auth=(ELASTICSEARCH_NAME, ELASTICSEARCH_PASSWORD),  # username, password
)
 
# create the ES connector
es_connector = MyEsConnector(es, "demo_es_rrf", to_keywords)
# load the documents
es_connector.add_documents(documents)
# keyword search
keyword_search_results = es_connector.search(query, 3)

print(json.dumps(keyword_search_results, indent=4, ensure_ascii=False))

  2. Ranking by vector search

# create the vector-database connector (MyVectorDBConnector is defined in an earlier section)
vecdb_connector = MyVectorDBConnector("demo_vec_rrf", get_embeddings)

# load the documents
vecdb_connector.add_documents(documents)

# vector search
vector_search_results = {
    "doc_"+str(documents.index(doc)): {
        "text": doc,
        "rank": i
    }
    for i, doc in enumerate(
        vecdb_connector.search(query, 3)["documents"][0]
    )
}  # convert the results into the same format as the keyword-search results above
 
print(json.dumps(vector_search_results, indent=4, ensure_ascii=False))
  3. Fusion ranking with RRF

Reference: https://learn.microsoft.com/zh-cn/azure/search/hybrid-search-ranking

Note: the original RRF paper and most production implementations use k = 60; the toy implementation below uses k = 1 so the scores are easy to follow.

def rrf(ranks, k=1):
    ret = {}
    # iterate over each ranking
    for rank in ranks:
        # iterate over each element of the ranking
        for id, val in rank.items():
            if id not in ret:
                ret[id] = {"score": 0, "text": val["text"]}
            # accumulate the RRF score
            ret[id]["score"] += 1.0 / (k + val["rank"])
    # sort by RRF score and return
    return dict(sorted(ret.items(), key=lambda item: item[1]["score"], reverse=True))
import json

# fuse the two retrieval rankings
reranked = rrf([keyword_search_results, vector_search_results])
 
print(json.dumps(reranked, indent=4, ensure_ascii=False))
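
A quick self-contained sanity check of rrf, with two hand-made rankings (the doc ids and texts are invented for illustration):

ranks_a = {"doc_0": {"text": "A", "rank": 0}, "doc_1": {"text": "B", "rank": 1}}
ranks_b = {"doc_1": {"text": "B", "rank": 0}, "doc_2": {"text": "C", "rank": 1}}

print(rrf([ranks_a, ranks_b]))
# doc_1 scores 1/(1+1) + 1/(1+0) = 1.5 and ranks first;
# doc_0 scores 1.0 and doc_2 scores 0.5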

Handling Tables in PDF Documents

1. Convert each PDF page to an image

pip install PyMuPDF
pip install matplotlib
import os
import fitz
from PIL import Image

def pdf2images(pdf_file):
    '''Convert each PDF page to a PNG image'''
    # output directory named after the PDF file (without extension)
    output_directory_path, _ = os.path.splitext(pdf_file)

    if not os.path.exists(output_directory_path):
        os.makedirs(output_directory_path)

    # load the PDF file
    pdf_document = fitz.open(pdf_file)

    # one image per page
    for page_number in range(pdf_document.page_count):
        # get the page
        page = pdf_document[page_number]

        # render it to a bitmap
        pix = page.get_pixmap()

        # create a PIL image from the bitmap
        image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

        # save as PNG
        image.save(f"./{output_directory_path}/page_{page_number + 1}.png")

    # close the PDF file
    pdf_document.close()
from PIL import Image
import os
import matplotlib.pyplot as plt

def show_images(dir_path):
    '''Display the PNG images in a directory'''
    for file in os.listdir(dir_path):
        if file.endswith('.png'):
            # open the image
            img = Image.open(os.path.join(dir_path, file))

            # display it
            plt.imshow(img)
            plt.axis('off')  # hide the axes
            plt.show()

pdf2images("llama2_page8.pdf")
show_images("llama2_page8")

2. Detect the tables in the images

class MaxResize(object):
    '''Scale an image down to a maximum size'''
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize(
            (int(round(scale * width)), int(round(scale * height)))
        )

        return resized_image
pip install torchvision
pip install transformers
pip install timm
import torchvision.transforms as transforms

# image preprocessing
detection_transform = transforms.Compose(
    [
        MaxResize(800),
        # convert the PIL image into a tensor that pytorch can process efficiently
        transforms.ToTensor(),
        # fixed normalization constants (the standard ImageNet mean and std)
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
from transformers import AutoModelForObjectDetection

# load the Table Transformer detection model
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection"
)
# convert and post-process the detected coordinates

import torch

def box_cxcywh_to_xyxy(x):
    '''Convert (center_x, center_y, w, h) boxes to (x1, y1, x2, y2)'''
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    '''Scale normalized boxes back up to the image size'''
    width, height = size
    boxes = box_cxcywh_to_xyxy(out_bbox)
    boxes = boxes * torch.tensor(
        [width, height, width, height], dtype=torch.float32
    )
    return boxes


def outputs_to_objects(outputs, img_size, id2label):
    '''Extract the bounding boxes from the model outputs'''
    m = outputs.logits.softmax(-1).max(-1)
    pred_labels = list(m.indices.detach().cpu().numpy())[0]
    pred_scores = list(m.values.detach().cpu().numpy())[0]
    pred_bboxes = outputs["pred_boxes"].detach().cpu()[0]
    pred_bboxes = [
        elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)
    ]

    objects = []
    for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
        class_label = id2label[int(label)]
        if not class_label == "no object":
            objects.append(
                {
                    "label": class_label,
                    "score": float(score),
                    "bbox": [float(elem) for elem in bbox],
                }
            )

    return objects
# detect tables and save each table region as a separate image file

def detect_and_crop_save_table(file_path):
    # load the image (a PDF page)
    image = Image.open(file_path)

    filename, _ = os.path.splitext(os.path.basename(file_path))

    # output directory
    cropped_table_directory = os.path.join(os.path.dirname(file_path), "table_images")

    if not os.path.exists(cropped_table_directory):
        os.makedirs(cropped_table_directory)

    # preprocess
    pixel_values = detection_transform(image).unsqueeze(0)

    # run table detection
    with torch.no_grad():
        outputs = model(pixel_values)

    # post-process to get the table regions
    id2label = model.config.id2label
    id2label[len(model.config.id2label)] = "no object"
    detected_tables = outputs_to_objects(outputs, image.size, id2label)

    print(f"Number of tables detected: {len(detected_tables)}")

    for idx in range(len(detected_tables)):
        # save each detected table region as its own image
        cropped_table = image.crop(detected_tables[idx]["bbox"])
        cropped_table.save(os.path.join(cropped_table_directory, f"{filename}_{idx}.png"))

detect_and_crop_save_table("llama2_page8/page_1.png")
show_images("llama2_page8/table_images")

Number of tables detected:

3. Table QA with the GPT-4 Vision API

import base64
from openai import OpenAI

client = OpenAI()

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def image_qa(query, image_path):
    base64_image = encode_image(image_path)
    response = client.chat.completions.create(
        model="gpt-4o",
        temperature=0,
        seed=42,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                    },
                },
            ],
        }],
    )

    return response.choices[0].message.content

response = image_qa("Which model performs best on the AGI Eval dataset, and what is its score?", "llama2_page8/table_images/page_1_0.png")
print(response)

# The model that performs best on the AGI Eval dataset is LLaMA 2 70B, with a score of 54.2.

4. Use GPT-4 Vision to generate descriptions of the table images, then vectorize those descriptions for retrieval

import chromadb
from chromadb.config import Settings


class NewVectorDBConnector:
    def __init__(self, collection_name, embedding_fn):
        chroma_client = chromadb.Client(Settings(allow_reset=True))

        # for demo purposes only; in practice you don't need to reset() every time
        chroma_client.reset()

        # create a collection
        self.collection = chroma_client.get_or_create_collection(
            name=collection_name)
        self.embedding_fn = embedding_fn

    def add_documents(self, documents):
        '''Add documents and their vectors to the collection'''
        self.collection.add(
            embeddings=self.embedding_fn(documents),  # one vector per document
            documents=documents,  # the original texts
            ids=[f"id{i}" for i in range(len(documents))]  # one id per document
        )

    def add_images(self, image_paths):
        '''Add images to the collection'''
        documents = [
            image_qa("Briefly describe the information in this image", image)
            for image in image_paths
        ]
        self.collection.add(
            embeddings=self.embedding_fn(documents),  # one vector per description
            documents=documents,  # the generated descriptions
            ids=[f"id{i}" for i in range(len(documents))],  # one id per image
            metadatas=[{"image": image} for image in image_paths]  # record the source image path in metadata
        )

    def search(self, query, top_n):
        '''Query the vector database'''
        results = self.collection.query(
            query_embeddings=self.embedding_fn([query]),
            n_results=top_n
        )
        return results
images = []
dir_path = "llama2_page8/table_images"
for file in os.listdir(dir_path):
    if file.endswith('.png'):
        # collect the PNG image paths
        images.append(os.path.join(dir_path, file))

new_db_connector = NewVectorDBConnector("table_demo", get_embeddings)
new_db_connector.add_images(images)

query = "Which model performs worst on the AGI Eval dataset, and what is its score?"

results = new_db_connector.search(query, 1)
metadata = results["metadatas"][0]
print("==== Search result ====")
print(metadata)
print("==== Answer ====")
response = image_qa(query, metadata[0]["image"])
print(response)

# ==== Search result ====
# [{'image': 'llama2_page8/table_images\\page_1_0.png'}]
# ==== Answer ====
# From the table, the model that performs worst on the AGI Eval dataset is **Falcon 7B**, with a score of **21.2**.

Document-Parsing Tools for RAG

Writing the code step by step, as above, is good for understanding the principles; in real projects you can use existing tools:

  • PyMuPDF: a foundational PDF-processing library, with rule-based table and image extraction (not very accurate)
  • RAGFlow: an open-source RAG engine built on deep document understanding; supports many document formats (very popular)
  • Unstructured.io: an open-source + SaaS document-parsing library; supports many document formats
  • LlamaParse: a paid API service from the LlamaIndex team; parsing is not guaranteed to be 100% accurate, and in practice text is occasionally lost or misplaced
  • Mathpix: a paid API service with good results; parses paragraph structure, tables, formulas, etc. (expensive)

In engineering terms, PDF parsing is inherently complex and fiddly work. None of the tools above is perfect; test them on your own scenario before choosing one.