PALM预训练语言生成模型是针对实际场景中常见的文本生成需求所设计的一个模型。模型利用大量无监督数据,通过结合自编码和自回归任务进行预训练,更贴合下游生成任务所同时需要的理解和生成能力。
针对实际场景中常见的文本生成需求,达摩院自主研发了PALM预训练语言生成模型。该模型通过在大规模文本上预训练得到,可作为下游自然语言生成任务的模型参数输入,以帮助提升下游任务的生成效果。PALM具有以下特点:
本模型是PALM通用预训练生成模型,在中文LCSTS数据集上进行finetune得到的文本摘要生成模型。PALM模型介绍,详见:PALM:Pre-training an Autoencoding&Autoregressive Language Model for Context-conditioned Generation
本模型主要用于给输入文档生成摘要内容。用户可以自行尝试各种输入文档。具体调用方式请参考代码示例。
在安装完成ModelScope-library之后即可使用text-generation的能力
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.outputs import OutputKeys
input = '昨天起,上海地铁3号线长江南路站、殷高西路站、江湾镇站三站进一步限流。体验发现,高峰时段排队5分钟能进站;不少乘客选择提前起床,“现在提前10到20分钟起床,即便限流也不会影响上班”。被限流的XDJMS,你们提前多久?新民网'
text_summary = pipeline(Tasks.text_generation, model='damo/nlp_palm2.0_text-generation_chinese-base')
result = text_summary(input)
print('输入文本:\n' + input + '\n')
print('文本摘要结果:\n' + result[OutputKeys.TEXT])
模型在新闻相关数据集上训练,在新闻等类似文章上摘要生成性能较好,其他垂直领域效果可能会有所下降。
本模型中文训练数据集是LCSTS,数据集240w左右, 具体数据可以下载
用户可以基于这个摘要模型在自己的摘要数据上做continue train,如果不是摘要任务,请前往通用PALM生成模型进行训练:PALM 2.0预训练生成模型-中文-base
模型采用2张NVIDIA V100机器训练, 超参设置如下:
train_epochs=15
max_sequence_length=128
batch_size=8
learning_rate=1e-3
optimizer=AdamW
import tempfile
from modelscope.msdatasets import MsDataset
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
# lcsts_test_set 为示例数据集,用户也可以使用自己的数据集进行训练
dataset_dict = MsDataset.load('lcsts_test_set', namespace='DAMO_NLP')
# 训练数据的输入出均为文本,需要将数据集预处理为输入为 src_txt,输出为 tgt_txt 的格式:
train_dataset = dataset_dict['train'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
eval_dataset = dataset_dict['test'].remap_columns({'text1': 'src_txt', 'text2': 'tgt_txt'})
# 用户自己数据集构造
# train_dataset_dict = {"src_txt": ["text1", "text2"], "tgt_txt": ["text1", "text2"]}
# eval_dataset_dict = {"src_txt": ["text1", "text2"], "tgt_txt": ["text1", "text2"]}
# train_dataset = MsDataset(Dataset.from_dict(train_dataset_dict))
# eval_dataset = MsDataset(Dataset.from_dict(eval_dataset_dict))
num_warmup_steps = 500
def noam_lambda(current_step: int):
current_step += 1
return min(current_step**(-0.5),
current_step * num_warmup_steps**(-1.5))
# 可以在代码修改 configuration 的配置
def cfg_modify_fn(cfg):
cfg.preprocessor.sequence_length = 128
cfg.train.lr_scheduler = {
'type': 'LambdaLR',
'lr_lambda': noam_lambda,
'options': {
'by_epoch': False
}
}
cfg.train.optimizer = {
"type": "AdamW",
"lr": 1e-3,
"options": {}
}
cfg.train.max_epochs = 15
cfg.train.dataloader = {
"batch_size_per_gpu": 8,
"workers_per_gpu": 1
}
return cfg
kwargs = dict(
model='damo/nlp_palm2.0_pretrained_chinese-base',
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=tempfile.TemporaryDirectory().name,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train()
模型在LCSTS测试数据评估结果
Rouge-1 | Rouge-2 | Rouge-L |
---|---|---|
43.31 | 28.81 | 39.78 |
如果我们的模型对您有帮助,请您引用我们的文章:
@inproceedings{bi-etal-2020-palm,
title = "{PALM}: Pre-training an Autoencoding & Autoregressive Language Model for Context-conditioned Generation",
author = "Bi, Bin and
Li, Chenliang and
Wu, Chen and
Yan, Ming and
Wang, Wei and
Huang, Songfang and
Huang, Fei and
Si, Luo",
booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
month = nov,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2020.emnlp-main.700",
doi = "10.18653/v1/2020.emnlp-main.700",
pages = "8681--8691"}