文本检索是信息检索领域的核心问题, 其在很多信息检索、NLP下游任务中发挥着非常重要的作用。 近几年, BERT等大规模预训练语言模型的出现使得文本表示效果有了大幅度的提升, 基于预训练语言模型构建的文本检索系统在召回、排序效果上都明显优于传统统计模型。
由于文档候选集合通常比较庞大,实际的工业搜索系统中候选文档数量往往在千万甚至更高的数量级, 为了兼顾效率和准确率,目前的文本检索系统通常是基于召回&排序的多阶段搜索框架。在召回阶段,系统的主要目标是从海量文本中去找到潜在跟query相关的文档,得到较小的候选文档集合(100-1000个)。召回完成后, 排序阶段的模型会对这些召回的候选文档进行更加复杂的排序, 产出最后的排序结果。 本模型为基于预训练的排序阶段模型。
本模型为基于ROM-Base预训练模型的通用领域中文语义相关性模型,模型以一个source sentence以及一个句子列表作为输入,最终输出source sentence与列表中每个句子的相关性得分(0-1,分数越高代表两者越相关)。
本模型主要用于给输入中文查询与文档列表产出相关性分数。用户可以自行尝试输入查询和文档。具体调用方式请参考代码示例。
在安装ModelScope完成之后即可使用语义相关性模型, 该模型以一个source sentence以及一个“sentence_to_compare"(句子列表)作为输入,最终输出source sentence与列表中每个句子的相关性得分(0-1,分数越高代表两者越相关)。 默认每个句子对长度不超过512。
# 可在CPU/GPU环境运行
from modelscope.models import Model
from modelscope.pipelines import pipeline
# Version less than 1.1 please use TextRankingPreprocessor
from modelscope.preprocessors import TextRankingTransformersPreprocessor
from modelscope.utils.constant import Tasks
inputs = {
'source_sentence': ["功和功率的区别"],
'sentences_to_compare': [
"功反映做功多少,功率反映做功快慢。",
"什么是有功功率和无功功率?无功功率有什么用什么是有功功率和无功功率?无功功率有什么用电力系统中的电源是由发电机产生的三相正弦交流电,在交>流电路中,由电源供给负载的电功率有两种;一种是有功功率,一种是无功功率。",
"优质解答在物理学中,用电功率表示消耗电能的快慢.电功率用P表示,它的单位是瓦特(Watt),简称瓦(Wa)符号是W.电流在单位时间内做的功叫做电功率 以灯泡为例,电功率越大,灯泡越亮.灯泡的亮暗由电功率(实际功率)决定,不由通过的电流、电压、电能决定!",
]
}
model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
pipeline_ins = pipeline(task=Tasks.text_ranking, model=model_id, model_revision='v1.1.0')
result = pipeline_ins(input=inputs)
print (result)
# {'scores': [0.9717444181442261, 0.005540850106626749, 0.8629351258277893]}
本模型基于中文公开语义相关性数据集DuReader-Retrieval进行训练,在垂类领域上的排序效果会有降低,请用户自行评测后决定如何使用。
模型采用4张NVIDIA V100机器训练, 主要超参设置如下:
train_epochs=3
max_sequence_length=128
batch_size=128
learning_rate=1e-5
optimizer=AdamW
neg_samples=8
# 需在GPU环境运行
# 加载数据集过程可能由于网络原因失败,请尝试重新运行代码
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
import tempfile
import os
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
# load dataset
ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
train_ds = ds['train'].to_hf_dataset()
dev_ds = ds['dev'].to_hf_dataset()
model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
def cfg_modify_fn(cfg):
cfg.task = 'text-ranking'
cfg['preprocessor'] = {'type': 'text-ranking'}
cfg['dataset'] = {
'train': {
'type': 'bert',
'query_sequence': 'query',
'pos_sequence': 'positive_passages',
'neg_sequence': 'negative_passages',
'text_fileds': ['text'],
'qid_field': 'query_id'
},
'val': {
'type': 'bert',
'query_sequence': 'query',
'pos_sequence': 'positive_passages',
'neg_sequence': 'negative_passages',
'text_fileds': ['text'],
'qid_field': 'query_id'
},
}
cfg['train']['neg_samples'] = 4
cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
cfg.train.max_epochs = 1
cfg.train.train_batch_size = 4
cfg.train.hooks = [{
'type': 'TextLoggerHook',
'interval': 1
}, {
'type': 'IterTimerHook'
}, {
'type': 'EvaluationHook',
'by_epoch': False,
'interval': 1000
}]
return cfg
kwargs = dict(
model=model_id,
train_dataset=train_ds,
work_dir=tmp_dir,
eval_dataset=dev_ds,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.nlp_text_ranking_trainer, default_args=kwargs)
trainer.train()
本模型在DuReader-Retrieval官方召回top50上的排序结果如下:
Model | MRR@10 |
---|---|
ERNIE | 72.9 |
Rom | 73.1 |
# 需在GPU环境运行
# 加载数据集过程可能由于网络原因失败,请尝试重新运行代码
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
import tempfile
import os
tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
# load dataset
ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
dev_ds = ds['dev'].to_hf_dataset()
model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
def cfg_modify_fn(cfg):
cfg.task = 'text-ranking'
cfg['preprocessor'] = {'type': 'text-ranking'}
cfg['dataset'] = {
'val': {
'type': 'bert',
'query_sequence': 'query',
'pos_sequence': 'positive_passages',
'neg_sequence': 'negative_passages',
'text_fileds': ['text'],
'qid_field': 'query_id'
},
}
cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 32
return cfg
kwargs = dict(
model=model_id,
train_dataset=None,
work_dir=tmp_dir,
eval_dataset=dev_ds,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(name=Trainers.nlp_text_ranking_trainer, default_args=kwargs)
# evalute 在单GPU(V100)上需要约四分钟,请耐心等待
metric_values = trainer.evaluate()
print(metric_values)
如果你觉得这个该模型对有所帮助,请考虑引用下面的相关的论文:
@article{Qiu2022DuReader\_retrievalAL,
title={DuReader\_retrieval: A Large-scale Chinese Benchmark for Passage Retrieval from Web Search Engine},
author={Yifu Qiu and Hongyu Li and Yingqi Qu and Ying Chen and Qiaoqiao She and Jing Liu and Hua Wu and Haifeng Wang},
journal={ArXiv},
year={2022},
volume={abs/2203.10232}
}