Efficient Tuning for Generative Diffusion Models - Swift-Prompt

Efficient Tuning for Generative Diffusion Models - Prompt

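This model performs efficient tuning of Stable Diffusion with the Swift library. With the Prompt-Tuner module, only a small number of parameters are trained, so you can efficiently build a Stable Diffusion model customized for your own scenario.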

模型描述

The base diffusion model is the pretrained Stable-Diffusion-v1-5. The trainable Tuner module accounts for less than 0.1% of the total model parameters.
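The citation list includes Visual Prompt Tuning (Jia et al.), so the Prompt-Tuner presumably follows that idea: the pretrained weights stay frozen and only a few learnable prompt tokens, prepended to a layer's token sequence, are trained. The NumPy sketch below illustrates the idea on one self-attention layer. It is not the Swift implementation; the layer, the sizes, and all names are hypothetical, and because the frozen part is tiny here, the trainable fraction (4%) is far larger than in the real model.

```python
import numpy as np

# Hypothetical toy sizes; these are NOT Stable Diffusion v1.5's real dimensions.
SEQ_LEN, DIM, NUM_PROMPTS = 16, 32, 4

rng = np.random.default_rng(0)

# Frozen "pretrained" weights of one single-head self-attention layer.
W_q, W_k, W_v = (rng.standard_normal((DIM, DIM)) / np.sqrt(DIM) for _ in range(3))

# The only trainable parameters: NUM_PROMPTS learnable prompt tokens.
prompts = 0.02 * rng.standard_normal((NUM_PROMPTS, DIM))


def self_attention(h):
    """Single-head self-attention that uses only the frozen weights."""
    q, k, v = h @ W_q, h @ W_k, h @ W_v
    scores = q @ k.T / np.sqrt(DIM)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v


def prompted_layer(x, prompt_tokens):
    """Prepend prompt tokens, attend over all tokens, then drop the prompt positions."""
    h = np.concatenate([prompt_tokens, x], axis=0)
    out = self_attention(h)
    return out[prompt_tokens.shape[0]:]  # same length as the input sequence


x = rng.standard_normal((SEQ_LEN, DIM))
y = prompted_layer(x, prompts)

trainable = prompts.size
frozen = W_q.size + W_k.size + W_v.size
print("output shape:", y.shape)
print(f"trainable fraction: {trainable / (trainable + frozen):.4%}")
```

During training only `prompts` would receive gradient updates. The prompts steer the frozen layer through attention (their keys and values mix into every output position), which is why so few parameters can still adapt the model.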

Intended Use and Scope

How to Use

In the ModelScope framework, the model can be invoked directly through its predefined Pipeline.

Code Example

from modelscope.pipelines import pipeline

# Build the inference pipeline for this model
sd_pipeline = pipeline(task='efficient-diffusion-tuning',
                       model='damo/multi-modal_efficient-diffusion-tuning-swift-prompt')
# Text prompt describing the image to generate
inputs = {'prompt': 'a street scene with a cafe and a restaurant sign in anime style'}
result = sd_pipeline(inputs)
print(f'Output: {result}.')

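Training Data

The training example below uses the Anime subset of the style_custom_dataset dataset (namespace damo), loaded with MsDataset.load.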
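When loading, the training script below calls remap_columns({"Image:FILE": "target:FILE"}) to rename the image column. The new name presumably matches the column the trainer reads its target images from (an assumption; the card does not say so). The effect on individual records can be pictured with this plain-Python sketch. The helper and the sample rows, including the "Text" column, are hypothetical and are not the MsDataset implementation.

```python
def remap_columns(rows, mapping):
    """Return new records whose keys are renamed per `mapping`; other keys are kept."""
    return [{mapping.get(key, key): value for key, value in row.items()} for row in rows]


# Hypothetical records; the real dataset's columns and paths may differ.
rows = [
    {"Image:FILE": "images/0001.png", "Text": "a cafe in anime style"},
    {"Image:FILE": "images/0002.png", "Text": "a quiet street in anime style"},
]

remapped = remap_columns(rows, {"Image:FILE": "target:FILE"})
print(remapped[0])
```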

Model Training and Validation

The following procedure trains and validates the SD-Tuner model on the dataset described above.

import os
import tempfile

import cv2
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

model_id = 'damo/multi-modal_efficient-diffusion-tuning-swift-prompt'

# Data preparation: load the Anime subset and rename the image column to target:FILE
train_dataset = MsDataset.load('style_custom_dataset',
                                namespace='damo',
                                split='train',
                                subset_name='Anime').remap_columns({"Image:FILE": "target:FILE"})

# mkdtemp keeps the directory; TemporaryDirectory().name would be deleted once the object is garbage-collected
tmp_dir = tempfile.mkdtemp()
max_epochs = 1
lr = 0.0001

def cfg_modify_fn(cfg):
    cfg.train.max_epochs = max_epochs
    cfg.train.lr_scheduler.T_max = max_epochs
    cfg.train.optimizer.lr = lr
    cfg.model.inference = False
    cfg.model.pretrained_tuner = None
    return cfg

kwargs = dict(
    model=model_id,
    work_dir=tmp_dir,
    train_dataset=train_dataset,
    cfg_modify_fn=cfg_modify_fn)

# Model training
trainer = build_trainer(name="trainer", default_args=kwargs)
trainer.train()
print('Efficient-diffusion-tuning-swift-prompt training finished.')

# Inference after training: load the model saved under <work_dir>/output
work_dir = os.path.join(tmp_dir, 'output')
inputs = {'prompt': 'a street scene with a cafe and a restaurant sign in anime style'}
pipe = pipeline(task=Tasks.efficient_diffusion_tuning, model=work_dir)
outputs = pipe(inputs)
cv2.imwrite('result.png', outputs['output_imgs'][0])
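cfg_modify_fn sets cfg.train.lr_scheduler.T_max = max_epochs. Assuming the configured scheduler is a cosine-annealing schedule such as PyTorch's CosineAnnealingLR (T_max is that scheduler's parameter name, but the model's configuration file is not shown here), tying T_max to max_epochs makes the learning rate decay from lr to its minimum over exactly the training run. The sketch below evaluates the closed form eta_t = eta_min + (lr - eta_min) * (1 + cos(pi * t / T_max)) / 2. The value max_epochs = 4 is hypothetical, chosen only to show several steps.

```python
import math


def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


base_lr = 1e-4     # matches lr in the training script
max_epochs = 4     # hypothetical; the script itself uses max_epochs = 1

for epoch in range(max_epochs + 1):
    print(epoch, f"{cosine_annealing_lr(epoch, base_lr, t_max=max_epochs):.2e}")
```

If T_max were left larger than max_epochs, training would stop before the learning rate had fully decayed. Whether the scheduler steps per epoch or per iteration depends on the trainer configuration.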

Related Papers and Citation

If this model is helpful to you, please cite the following related papers:

@inproceedings{jia2022vpt,
  title={Visual Prompt Tuning},
  author={Jia, Menglin and Tang, Luming and Chen, Bor-Chun and Cardie, Claire and Belongie, Serge and Hariharan, Bharath and Lim, Ser-Nam},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022}
}
@misc{rombach2021highresolution,
  title={High-Resolution Image Synthesis with Latent Diffusion Models}, 
  author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
  year={2021}
}