Advertisement

深度学习论文: Learning Transferable Visual Models From Natural Language Supervision

阅读量:

深度学习论文: Learning Transferable Visual Models From Natural Language Supervision
Learning Transferable Visual Models From Natural Language Supervision
PDF: https://arxiv.org/pdf/2103.00020.pdf
官方代码: https://github.com/OpenAI/CLIP
PyTorch代码: https://github.com/shanglianlm0525/CvPytorch
PyTorch代码: https://github.com/shanglianlm0525/PyTorch-Networks

1 概述

CLIP(对比性语言-图像预训练)是一个在各种(图像,文本)对上进行训练的神经网络。它可以通过自然语言指令,在给定图像的情况下预测最相关的文本片段,而不是直接为任务进行优化,类似于GPT-2和GPT-3的零样本能力。发现CLIP在ImageNet的“零样本”上与原始的ResNet50的性能相匹配,而且没有使用任何原始的128万个标记示例,克服了计算机视觉中的几个重要挑战。

2 CLIP (Contrastive Language-Image Pre-Training)

首先是构建CLIP,CLIP实际上是一个预训练模型,包括文本编辑和图像编辑器两部分,分别计算文本向量和图像向量的相似度,以预测它们是否为一对,如图1所示。CLIP将图像和文本先分别输入一个图像编码器image_encoder和一个文本编码器text_encoder,得到图像和文本的向量表示 I_{f}T_{f} 。然后将图像和文本的向量表示映射到一个联合多通道空间,得到新的可直接进行比较的图像和文本的向量表示 I_{e}T_{e} 。然后计算图像向量和文本向量之间的cosine相似度。最后,对比学习的目标函数就是让正样本对的相似度较高,负样本对的相似度较低。矩阵中非对角线上的元素都是负样本,n个正样本, n^{2}-n个负样本,有了正负样本,模型就可以通过对比学习的方式去训练了,不需要任何手工标注。但是这种无监督的训练方式,是需要大量的训练数据的。
在这里插入图片描述

CLIP核心实现的伪代码:
在这里插入图片描述
推理代码:

复制代码
    import torch
    import clip
    from PIL import Image
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    
    image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
    
    with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    
    print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

在训练过程中,使用了5个不同的ResNet模型:ResNet-50,ResNet-101,4xResNet-50,16xResNet-50和64xResNet-50,以及3个ViT模型:ViT-B/32,ViT-B/16和ViT-L/14。训练时的迭代轮数为32,优化器使用Adam,并采用解耦权重衰减正则化技术,调度器采用余弦调度。初始超参数通过网格搜索确定。此外,训练过程中使用了半精度技术,并在训练完成后,再对稍大一些的336像素分辨率进行一个额外的训练轮数。最终,选择在336像素分辨率上表现最好的ViT-L/14作为后文所指的CLIP模型。

3 Experiments

3-1 Zero-Shot Transfer

CLIP实现零样本推理的方法是通过预训练获得文本和图像的特征,而没有分类头。为了解决这个问题,作者提出了一种利用自然语言的方法,即prompt template(提示模板)。对于ImageNet的类别,首先将其转化为句子的形式,例如"A photo of a {object}"。由于ImageNet有1000个类别,因此会生成1000个句子。然后,通过之前预训练好的文本编码器,这1000个句子可以得到1000个文本特征。

虽然可以直接使用类别单词来提取文本特征,但是在模型预训练时,与图像配对的都是句子,因此在推理时使用单词效果会下降。因此,使用句子作为提示更为有效。将待分类的图像送入图像编码器,得到其特征。然后,将图像特征与1000个文本特征计算余弦相似度,并选择最相似的文本特征对应的句子,从而完成分类任务。这种方法不仅限于这1000个类别,任何类别都可以使用。

通过这种方式,CLIP完全摆脱了分类标签的限制,无论是在训练还是推理过程中,都不需要预先定义好的标签列表。
在这里插入图片描述

复制代码
    import os
    import clip
    import torch
    from torchvision.datasets import CIFAR100
    
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)
    
    # Download the dataset
    cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
    
    # Prepare the inputs
    image, class_id = cifar100[3637]
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
    
    # Calculate features
    with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)
    
    # Pick the top 5 most similar labels for the image
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(5)
    
    # Print the result
    print("\nTop predictions:\n")
    for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3-2 Representation Learning

Fine-tuning 与 linear probe

  • Linear probe:把一个训练好的模型冻结住,只训练最后一层的分类头去做分类任务。
  • Fine tune:对整个模型参数进行端到端的训练

CLIP 的作者选择了linear probe的方法进行预训练模型在其他数据集上表现的对比。
在这里插入图片描述

复制代码
    import os
    import clip
    import torch
    
    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR100
    from tqdm import tqdm
    
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)
    
    # Load the dataset
    root = os.path.expanduser("~/.cache")
    train = CIFAR100(root, download=True, train=True, transform=preprocess)
    test = CIFAR100(root, download=True, train=False, transform=preprocess)
    
    
    def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))
    
            all_features.append(features)
            all_labels.append(labels)
    
    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
    
    # Calculate the image features
    train_features, train_labels = get_features(train)
    test_features, test_labels = get_features(test)
    
    # Perform logistic regression
    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
    classifier.fit(train_features, train_labels)
    
    # Evaluate using the logistic regression classifier
    predictions = classifier.predict(test_features)
    accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
    print(f"Accuracy = {accuracy:.3f}")
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

全部评论 (0)

还没有任何评论哟~