RLEG:Representation Learning with Embedding Generation,将生成模型与表征模型结合在一起,
利用预训练的Diffusion生成模型,在特征空间在线生成更多的图文特征样本,指导和增强表征学习过程,
训练完成后在Inference阶段仅保留表征模型部分,
得到的多模态表征在图像分类、目标检测、语义分割、图文检索、图像生成等多个下游任务上得到性能提升.
Model | layers | width | heads | embedding dim |
---|---|---|---|---|
Vision Transformer | 24 | 1024 | 16 | 768 |
Text Transformer | 12 | 768 | 12 | 768 |
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from PIL import Image
import requests
model = Model.from_pretrained('damo/multi-modal_rleg-vit-large-patch14')
p = pipeline(task=Tasks.generative_multi_modal_embedding, model=model)
url = 'http://clip-multimodal.oss-cn-beijing.aliyuncs.com/lingchen/demo/dogs.jpg'
image = Image.open(requests.get(url, stream=True).raw)
text = 'dogs playing in the grass'
img_embedding = p.forward({'image': image})['img_embedding']
print('image embedding: {}'.format(img_embedding))
text_embedding = p.forward({'text': text})['text_embedding']
print('text embedding: {}'.format(text_embedding))
模型在数据集上训练,有可能产生一些偏差,请用户自行评测后决定如何使用。
训练数据共约4亿公开图文数据集,包含CC3M/CC12M/LAION-400M/YFCC100M公开数据集。
–图像输入:RandomResizedCrop到224*224
,随机水平翻转
–文本输入:最多保留77个token
初始LR为0.001,每30000个iteration之后减小为1/5,共训练90000个iteration。
该模型在3个公开图像识别或图文检索数据集上进行了zero-shot评测,对比CLIP相同参数量模型,达到SOTA效果,详细Top1识别或检索准确率如下表所示:
ImageNet-Top1 | ImageNet-Top5 | Flickr30K-I2T | Flickr30K-T2I | MSCOCO-I2T | MSCOCO-T2I | |
---|---|---|---|---|---|---|
CLIP (Large) | 75.3 | 94.6 | 85.7 | 64.6 | 57.3 | 36.1 |
RLEG (Large) | 80.9 | 96.6 | 90.5 | 75.4 | 62.1 | 45.4 |