StructBERT中文文本相似度模型是在structbert-base-chinese预训练模型的基础上,用atec、bq_corpus、chineseSTS、lcqmc、paws-x-zh五个数据集(52.5w条数据,正负比例0.48:0.52)训练出来的相似度匹配模型。由于license权限问题,目前只上传了BQ_Corpus、chineseSTS、LCQMC这三个数据集。
其他数据集:
模型基于Structbert-base-chinese,按照BERT文本对分类的方式,在atec、bq_corpus、chineseSTS、lcqmc、paws-x-zh五个数据集(52.5w条数据)上进行微调。
你可以使用StructBERT中文文本相似度模型,对通用领域的文本相似度任务进行推理。
输入形如(文本A,文本B)的文本对数据,模型会给出该文本对的是否相似的标签(不相似, 相似)以及相应的概率。
在安装完成ModelScope-lib,请参考 modelscope环境安装 。
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
similarity_pipeline = pipeline(Tasks.sentence_similarity, 'damo/nlp_structbert_sentence-similarity_chinese-base')
similarity_pipeline(input=('商务职业学院和财经职业学院哪个好?', '商务职业学院商务管理在哪个校区?'))
import os.path as osp
from modelscope.trainers import build_trainer
from modelscope.msdatasets import MsDataset
from modelscope.utils.hub import read_config
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
dataset_id = 'BQ_Corpus'
WORK_DIR = 'workspace'
cfg = read_config(model_id)
cfg.train.max_epochs = 2
cfg.train.work_dir = WORK_DIR
cfg.train.hooks = cfg.train.hooks = [{
'type': 'TextLoggerHook',
'interval': 100
}]
cfg_file = osp.join(WORK_DIR, 'train_config.json')
cfg.dump(cfg_file)
train_dataset = MsDataset.load(dataset_id, namespace='DAMO_NLP', split='train').to_hf_dataset()
eval_dataset = MsDataset.load(dataset_id, namespace='DAMO_NLP', split='validation').to_hf_dataset()
# map float to index
def map_labels(examples):
map_dict = {0: "不相似", 1: "相似"}
examples['label'] = map_dict[int(examples['label'])]
return examples
train_dataset = train_dataset.map(map_labels)
eval_dataset = eval_dataset.map(map_labels)
kwargs = dict(
model=model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
cfg_file=cfg_file)
trainer = build_trainer(name='nlp-base-trainer',default_args=kwargs)
print('===============================================================')
print('pre-trained model loaded, training started:')
print('===============================================================')
trainer.train()
print('===============================================================')
print('train success.')
print('===============================================================')
for i in range(cfg.train.max_epochs):
eval_results = trainer.evaluate(f'{WORK_DIR}/epoch_{i+1}.pth')
print(f'epoch {i} evaluation result:')
print(eval_results)
print('===============================================================')
print('evaluate success')
print('===============================================================')
模型训练数据有限,不能包含所有行业,因此在特定行业数据上,效果可能存在一定偏差。
数据集 | Avg | ATEC | bq_corpus | ChineseSTS | LCQMC | paws-x-zh |
---|---|---|---|---|---|---|
Accuracy | 0.8831 | 0.8662 | 0.8668 | 0.9736 | 0.8959 | 0.8629 |
@article{wang2019structbert,
title={Structbert: Incorporating language structures into pre-training for deep language understanding},
author={Wang, Wei and Bi, Bin and Yan, Ming and Wu, Chen and Bao, Zuyi and Xia, Jiangnan and Peng, Liwei and Si, Luo},
journal={arXiv preprint arXiv:1908.04577},
year={2019}
}