基于此模型在QBQTC数据集上训练得到MaSTS中文文本相似度-CLUE语义匹配模型。集成在QBQTC数据集上训练得到的相似度匹配模型,在CLUE语义匹配榜上获得了第一名的成绩。
使用教程请参考 https://developer.aliyun.com/article/1128425 。
模型通过对语义匹配任务改进的掩码策略进行无监督预训练。按照BERT文本对分类的方式,在QBQTC数据集上进行微调。
模型主要用于在QBQTC数据集上进行微调。
模型训练数据有限,在其他数据上效果可能存在一定偏差。
请参考ModelScope环境安装。
import os.path as osp
from modelscope.trainers import build_trainer
from modelscope.msdatasets import MsDataset
from modelscope.utils.hub import read_config
from modelscope.utils.constant import Tasks
model_id = 'damo/nlp_masts_backbone_clue_chinese-large'
dataset_id = 'QBQTC'
WORK_DIR = 'workspace'
# 通过这个方法修改cfg
def cfg_modify_fn(cfg):
# 将backbone模型加载到句子相似度的模型类中
cfg.task = Tasks.sentence_similarity
# 使用句子相似度的预处理器
cfg['preprocessor'] = {
'train': {
'type': 'sen-sim-tokenizer',
# 第一个字段的key
'first_sequence': 'query',
# 第二个字段的key
'second_sequence': 'title',
# label的key
'label': 'label',
'mode': 'train',
},
'val': {
'type': 'sen-sim-tokenizer',
# 第一个字段的key
'first_sequence': 'query',
# 第二个字段的key
'second_sequence': 'title',
# label的key
'label': 'label',
'mode': 'eval',
},
}
# lr_scheduler的配置
return cfg
train_dataset = MsDataset.load(dataset_id, namespace='damo', subset_name='default', split='train', keep_default_na=False)
eval_dataset = MsDataset.load(dataset_id, namespace='damo', subset_name='public', split='test', keep_default_na=False)
kwargs = dict(
model=model_id,
model_revision='v1.0.2',
train_dataset=train_dataset,
eval_dataset=eval_dataset,
cfg_modify_fn=cfg_modify_fn,
)
trainer = build_trainer(default_args=kwargs)
print('===============================================================')
print('pre-trained model loaded, training started:')
print('===============================================================')
trainer.train()
print('===============================================================')
print('train success.')
print('===============================================================')
for i in range(cfg.train.max_epochs):
eval_results = trainer.evaluate(f'{WORK_DIR}/epoch_{i+1}.pth')
print(f'epoch {i} evaluation result:')
print(eval_results)
print('===============================================================')
print('evaluate success')
print('===============================================================')