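The MaSTS Chinese Text Similarity - CLUE Semantic Matching model is a similarity matching model obtained by training the MaSTS pre-trained model - CLUE Semantic Matching on the QBQTC dataset. An ensemble of this model ranked first on the CLUE semantic matching leaderboard.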
For a usage tutorial, see https://developer.aliyun.com/article/1128425 and the Jupyter Notebook tutorial.ipynb.
The model is fine-tuned on the QBQTC dataset in the same way as BERT text-pair classification.
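To make the text-pair classification input concrete, here is a toy sketch of how a BERT-style model usually packs two texts into one sequence: `[CLS] A [SEP] B [SEP]`, with segment ids marking which text each token belongs to. The function name `pack_text_pair`, the character-level splitting, and the `max_len` value are illustrative assumptions; the actual MaSTS tokenizer and preprocessing are not shown in this document.

```python
def pack_text_pair(text_a, text_b, max_len=32):
    """Pack two texts as [CLS] A [SEP] B [SEP] with segment ids.

    Toy illustration only: tokens are single characters, which is an
    assumption and not the real MaSTS tokenizer.
    """
    if max_len < 3:
        raise ValueError('max_len must leave room for [CLS] and two [SEP] tokens')
    tokens_a = list(text_a)
    tokens_b = list(text_b)
    # Reserve three positions for [CLS] and the two [SEP] markers.
    budget = max_len - 3
    # Trim the longer side one token at a time until the pair fits.
    while len(tokens_a) + len(tokens_b) > budget:
        longer = tokens_a if len(tokens_a) >= len(tokens_b) else tokens_b
        longer.pop()
    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
    # Segment 0 covers [CLS], text A and the first [SEP];
    # segment 1 covers text B and the final [SEP].
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


tokens, segment_ids = pack_text_pair('小孩咳嗽感冒', '小孩感冒久咳嗽')
print(tokens)
print(segment_ids)
```

The classifier then predicts one of the three relevance labels from the packed sequence, typically from the representation at the `[CLS]` position.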
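Given a text pair of the form (text A, text B), the model returns the relevance labels for the pair ("0", "1", "2") together with their probabilities. The labels mean: 0, poorly related; 1, somewhat related; 2, highly related. A larger number indicates higher relevance.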
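The label scheme above can be turned into a small post-processing helper. The sketch below assumes the pipeline result is a dict with parallel `labels` and `scores` lists; that shape, the helper name `interpret`, and the numbers in `example` are assumptions made for illustration, not output produced by the model.

```python
# Meaning of each relevance label, as described above.
RELEVANCE = {
    '0': 'poorly related',
    '1': 'somewhat related',
    '2': 'highly related',
}


def interpret(result):
    """Pick the most probable label and attach its meaning.

    Assumes `result` has parallel 'labels' and 'scores' lists
    (an assumed output shape, check it against your own pipeline output).
    """
    pairs = list(zip(result['labels'], result['scores']))
    label, score = max(pairs, key=lambda pair: pair[1])
    label = str(label)
    return {'label': label, 'meaning': RELEVANCE[label], 'probability': score}


# Made-up values, only to show the expected structure.
example = {'labels': ['0', '1', '2'], 'scores': [0.05, 0.25, 0.70]}
print(interpret(example))
```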
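The training data is limited, so results on data from other sources may deviate from those reported here.
To set up the environment, see the ModelScope environment installation guide.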
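from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Load the sentence-similarity pipeline with the fine-tuned MaSTS model.
similarity_pipeline = pipeline(Tasks.sentence_similarity, 'damo/nlp_masts_sentence-similarity_clue_chinese-large', model_revision='v1.0.0')
# Score one (text A, text B) pair and print the relevance labels with their probabilities.
print(similarity_pipeline(input=('小孩咳嗽感冒', '小孩感冒过后久咳嗽该吃什么药育儿问答宝宝树')))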
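import os
import os.path as osp
from modelscope.trainers import build_trainer
from modelscope.msdatasets import MsDataset
from modelscope.utils.hub import read_config
# Pre-trained backbone to fine-tune and the dataset to train on.
model_id = 'damo/nlp_masts_backbone_clue_chinese-large'
dataset_id = 'QBQTC'
WORK_DIR = 'workspace'
# Create the work directory up front so the config file can be written into it.
os.makedirs(WORK_DIR, exist_ok=True)
# Read the model's default config, point it at the work directory, and save a local copy.
cfg = read_config(model_id, revision='v1.0.0')
cfg.train.work_dir = WORK_DIR
cfg_file = osp.join(WORK_DIR, 'train_config.json')
cfg.dump(cfg_file)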
train_dataset = MsDataset.load(dataset_id, namespace='damo', subset_name='default', split='train', keep_default_na=False)
eval_dataset = MsDataset.load(dataset_id, namespace='damo', subset_name='public', split='test', keep_default_na=False)
kwargs = dict(
    model=model_id,
    model_revision='v1.0.0',
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    cfg_file=cfg_file,
)
trainer = build_trainer(default_args=kwargs)
print('===============================================================')
print('pre-trained model loaded, training started:')
print('===============================================================')
trainer.train()
print('===============================================================')
print('train success.')
print('===============================================================')
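for i in range(cfg.train.max_epochs):
    # The checkpoint for epoch i+1 is loaded from epoch_{i+1}.pth, so report the same 1-based number.
    eval_results = trainer.evaluate(f'{WORK_DIR}/epoch_{i+1}.pth')
    print(f'epoch {i+1} evaluation result:')
    print(eval_results)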
print('===============================================================')
print('evaluate success')
print('===============================================================')
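The evaluation loop above prints one result per epoch. If you collect those results in a list, a small helper can pick the best checkpoint. This is only a sketch: the helper name `best_epoch`, the metric key `'accuracy'`, and the numbers below are assumptions for illustration; check the keys that `trainer.evaluate` actually returns in your environment.

```python
def best_epoch(results_per_epoch, metric='accuracy'):
    """Return (epoch_number, value) for the highest metric value.

    `results_per_epoch` is a list of dicts, one per epoch, in training order.
    Epoch numbers are 1-based to match the epoch_{n}.pth checkpoint names.
    The metric key is an assumption and may differ in practice.
    """
    scored = [(index + 1, result[metric])
              for index, result in enumerate(results_per_epoch)]
    return max(scored, key=lambda item: item[1])


# Made-up evaluation results, only to show the expected structure.
collected = [{'accuracy': 0.70}, {'accuracy': 0.78}, {'accuracy': 0.75}]
epoch, value = best_epoch(collected)
print(f'best checkpoint: epoch_{epoch}.pth ({value})')
```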
Dataset | Macro F1 | Accuracy |
---|---|---|
Public test set (test_public) | 74.1 | 79.7 |
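For readers unfamiliar with the two metrics in the table, the sketch below computes accuracy and macro F1 for the three relevance labels using their standard definitions. The evaluation script behind the reported numbers is not shown in this document, so treat this as an illustration of the metrics rather than the exact scoring code; the toy labels are made up. The values come out as fractions between 0 and 1, while the table reports percentages.

```python
def accuracy(y_true, y_pred):
    """Fraction of examples whose predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true, y_pred, labels=(0, 1, 2)):
    """Unweighted mean of the per-class F1 scores."""
    f1_scores = []
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        if precision + recall:
            f1_scores.append(2 * precision * recall / (precision + recall))
        else:
            f1_scores.append(0.0)
    return sum(f1_scores) / len(f1_scores)


# Made-up gold labels and predictions for six text pairs.
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 1]
print(f'accuracy: {accuracy(y_true, y_pred):.3f}')
print(f'macro F1: {macro_f1(y_true, y_pred):.3f}')
```

Macro F1 weights every class equally, so it is more sensitive than accuracy to errors on the less frequent relevance labels.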