dalle:zero-shot text-to-image generation
以下是DALL·E生成模型的总结:
DALL·E模型概述
DALL·E是一种结合了离散VAE和扩散模型的生成模型,能够从输入的文字描述中生成高质量的图像,并能根据提供的上下文信息进行超现实想象。模型结构
DALLE架构
- 编码器:将输入的文字描述转换为嵌入表示。
- 解码器:将嵌入表示转化为图像特征并生成最终图像。
- 多阶段训练:第一阶段仅学习dVAE编码部分;第二阶段引入Transformer进行自回归训练。
关键组件- VAE:用于从图片中提取潜在代码。
- DALLE:一种离散扩散模型。
- CLIP:用于对齐文字描述与图片特征。
工作流程- 文本 → 编码器 → VAE编码 → DALLE解码 → 图片输出。
- 图片输入可选,则先通过VAE提取特征再进行后续处理。
实现细节- 使用了Gumbel softmax技巧解决离散输出问题。
- 解码时采用多轮采样以提高质量。
- 通过预训练后的DiscreteVAE作为基础架构。
实现特点
结合了离散VAE与扩散模型的优势。
支持超现实想象功能(如提供多个不同结果)。
使用改进的自注意力机制(如GEGLU)提升性能。总结
DALL·E通过整合离散VAE与扩散模型,在保持高质量的同时实现了对超现实想象的支持,并展示了强大的上下文理解能力。
DALL·E—从文本到图像,超现实主义的图像生成器 - 知乎欢迎关注Smarter, 构建CV世界观超现实主义的独特之处在于其将梦幻与现实完美融合以呈现绝对的真实,在这一领域中,DALL·E作为一款突破性的图像生成工具,能够直接通过文本描述创造出具有超现实主义风格的画面,这不仅赋予了机器人的视觉表达能力,更使其具备了与顶尖艺术家及设计师相当的创作能力。

从实现角度来看,DALL·E模型的表现仍存在诸多疑问.官方尚未公布其具体的算法细节,GitHub上开源的代码库也仅提供了dVAE这一核心组件,目前仅实现了原始架构的一半性能.相比之下,Hugging Face相关的开源项目虽然在一定程度上推进了相关技术的发展,但与DALL·E的实际能力仍有较大差距.

从技术实现的角度分析DALL-E模型的发展历程 - 知乎
最近,DALLE与VQGAN凭借其强大的图像生成能力引发了广泛关注,其中,DALLE可通过输入一段描述性文字,创造出具有高度艺术性的图片,VQGAN则能生成分辨率极高的百万像素级图像.这些先进的生成模型均源自于对VAE及VQVAE核心理念的继承与创新.

这篇文章主要探讨了基于VAE与VQVAE的改进版本DALLE的工作原理及其在图像生成领域的应用前景。作为一种分阶段设计的算法,DALLE需训练三个关键组件:dvae、dalle以及与文本处理相关的辅助模块。具体而言,dvae内部采用编码器架构来提取图像特征,而dalle则是一个结合了图像与文本双重表征的自回归语言模型,其核心机制借鉴于transformer架构的设计理念。值得注意的是,dalle并非简单的拼接text与image特征后套用transformer框架,而是遵循了一整套系统化的流程:输入文本信息后依次生成图像特征并完成解码过程,最终通过CLIP模型对候选图像进行筛选以获得最优结果这一流程完全不同于传统的代理任务设定,而是建立在一个完整的自监督学习框架之上。
训练阶段:
- 独立完成dVAE的训练工作(最终获得encoder、visual codebook以及decoder);
- 运用Transformer架构进行处理,在对text和image分别进行编码后整合成统一表示,并模仿GPT-3模型构建一个自回归的语言模型。
推理阶段:
将输入划分为两种类型:第一种仅包含text;第二种为text与image结合输入。
对于仅包含text的情况,在对text进行编码后通过transformer架构进行自回归解码生成image tokens;随后将这些tokens通过dVAE的codebook映射至latent空间以获得latent code,并通过dVAE的decoder将其解码生成图片。
当同时提供text和image时,在生成image tokens的过程中引入前缀信息(具体代码中采用前面14×32个token作为默认prefix),这种设计有助于提高生成效果的可控性;其余处理流程与仅包含text的情况一致。
最后一步,请评估一下模型性能的表现吧!因为评估过程涉及多轮交互对话,并非单次输入就能完成任务。每一次对话都会根据前一次的结果不断优化模型输出质量(这可能涉及到多次迭代调参),最终才能得到满意的结果输出。为了实现这一目标就需要设计一个完整的评估指标体系(包括准确率、召回率等多个关键指标),这样才能全面衡量模型的实际性能表现
1.Introduction
用GAN代替VAE能够显著提升图像的真实度,在生成模型领域中(尤其是涉及超分辨率等场景),通常的做法是将GAN应用于解码器模块中。这种做法之所以被广泛采用的原因在于GAN能够生成更加逼真的图像内容。然而这一方法并非完美无缺,在实际应用中存在一些局限性:潜在的问题在于训练样本可能会受到严重伪影的影响(例如目标变形、不合理物体放置或前景与背景元素过度混合等)。此外,在研究超分辨率领域时发现:传统的CNN解码器往往会生成图像呈现出过平滑的现象(即缺乏锐利的边缘),但同时GAN方法可能会输出与原始图像无关的一些特征内容。
2.method
stage 1: 对输入图像经过dVAE模型处理后,在空间分辨率上从(原始) 256× 二维矩阵降维至 $ 的二维矩阵,并生成每个像素位置上的嵌码。具体而言,在此过程中, 每个像素位置处嵌码的空间维度大小设定为819\texttt{D}个可能取值, 这一设定确保了后续潜在空间中嵌入表示具有可学习性。随后通过该区域对应的编码器输出结果在潜在空间中定位并采样对应的潜在向量, 最终完成对原始图像细节信息的有效重建和压缩存储。
第二阶段:通过BPE encoder对原始文本进行编码操作,在生成最多256个文本tokens的同时补充不足部分(即不足部分通过pad技术补齐)。随后将生成的256个文本tokens与预处理后的1024个图像tokens结合形成一个1280维的特征向量,并将此特征向量输入到transformer模型中进行自回归预测。

dVAE是一种变分自编码器(VAE),与传统的VAE不同在于其对数据分布的建模方式。相比于VAE中基于均值-方差表征高斯分布的方法以及通过KL散度实现先验与后验之间分布匹配的技术手段,VQVAE则采用了引入后验分布的方式进行约束,并通过重参数化技术从均值-方差表征的高斯分布中提取潜在变量进行解码过程。具体而言,VQVAE通过编码器网络将输入数据映射至中间编码表示,随后利用最近邻搜索机制将中间编码映射到代码书中的k个基向量之一,再通过解码器网络对潜在代码进行重建。在这一过程中,最近邻搜索采用了argmax运算来确定代码书中的索引位置,但由于其不可导特性而无法用于优化过程,dALLE则采用Gumbel-Softmax技巧巧妙地解决了这一问题:将不可导的argmax运算近似为可导的softmax操作,从而实现了对潜在变量的有效优化。

第一段描述了一个编码器模块,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是编码器模块本身,在KL空间中编码的第一部分对应的是...
2.1 stage 1:learning the visual codebook
kl weight=6.6,K=8196
2.2 stage 2:learning the prior

这部分属于DALLE模型的技术构成阶段,在其发展过程中经历了重要的技术演进。具体而言,在DALLE 2版本中实现了对扩散模型技术的全面应用。其核心组件——自回归transformer模型,在这种架构下承担着将经过BPE编码器处理后的文本信息与经过dVAE编码器处理后的图像特征进行融合的任务。从损失函数的设计角度来看,在这一模块中所采用的方法与CLIP模型具有高度的一致性,并且在理论基础上有深刻的关联性
2.3 推理
在推理过程中采用dVAE的解码机制获取初步结果,在此基础上通过CLIP技术筛选出最优候选供后续处理使用。
2.4 data collection
120亿的参数量,3.3m对text-image对。
3.代码
VAE:
vae = DiscreteVAE(
image_size=256,
num_layers=3, # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
num_tokens=8192, # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
codebook_dim=512, # codebook dimension
hidden_dim=64, # hidden dimension
num_resnet_blocks=1, # number of resnet blocks
temperature=0.9, # gumbel softmax temperature, the lower this is, the harder the discretization
straight_through=False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)
AI助手
原始图像张量4\times 3\times 256\times 256经过归一化操作得到输入特征图4\times 8196\times 32\times 32;随后通过应用Gumbel-Softmax函数获得soft_one_hot表示;样本采样采用爱因斯坦求和运算符进行计算得到采样索引张量4\times 512\times 32\times 32;最终输出图像张量4\times 3\times 256\times 256由解码器生成。
DiscreteVAE(
(codebook): Embedding(8192, 512)
(encoder): Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(1): Sequential(
(0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(2): Sequential(
(0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(3): ResBlock(
(net): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(4): Conv2d(64, 8192, kernel_size=(1, 1), stride=(1, 1))
)
(decoder): Sequential(
(0): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
(1): ResBlock(
(net): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(3): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(4): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(5): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
AI助手
dalle:
dalle = DALLE(
dim=1024,
vae=vae, # automatically infer (1) image sequence length and (2) number of image tokens
num_text_tokens=10000, # vocab size for text
text_seq_len=256, # text sequence length
depth=12, # should aim to be 64
heads=16, # attention heads
dim_head=64, # attention head dimension
attn_dropout=0.1, # attention dropout
ff_dropout=0.1 # feedforward dropout
)
AI助手
image:4,3,256,256/text:4,256->text_range:256,text_seq_len:1280,num_image_tokens:8192,num_text_tokens:10256->text:4,256->text=F.pad:4,257->tokens=text_emb(text):4,257,1024->image=vae.get_codebook_indices(image)->logits=self(image):4,8196,32,32->codebook_indices=logits.argmax:4,1024->image_emb=image_emb(image):4,1024,1024->tokens:4,1281,1024->out=self.transformers(tokens:4,1280,1024):4,1280,1024->logits:4,1280,18448->offsetted_image:4,1028,text:4,257,labels:4,1280->logits:4,18448,1280
DALLE(
(vae): DiscreteVAE(
(codebook): Embedding(8192, 1024)
(encoder): Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(1): Sequential(
(0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(2): Sequential(
(0): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(3): ResBlock(
(net): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(4): Conv2d(64, 8192, kernel_size=(1, 1), stride=(1, 1))
)
(decoder): Sequential(
(0): Conv2d(1024, 64, kernel_size=(1, 1), stride=(1, 1))
(1): ResBlock(
(net): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU()
(4): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(3): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(4): Sequential(
(0): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): ReLU()
)
(5): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
(transformer): Transformer(
(layers): SequentialSequence(
(layers): ModuleList(
(0): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(1): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(2): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(3): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(4): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(5): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(6): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(7): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(8): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(9): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(10): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
(11): ModuleList(
(0): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): CachedAs(
(fn): Attention(
(to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
(to_out): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): Dropout(p=0.1, inplace=False)
)
)
)
)
)
)
)
(1): LayerScale(
(fn): PreNorm(
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(norm_out): Identity()
(fn): CachedAs(
(fn): PreShiftToken(
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=1024, out_features=8192, bias=True)
(1): GEGLU()
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=4096, out_features=1024, bias=True)
)
)
)
)
)
)
)
)
)
)
(to_logits): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=18448, bias=True)
)
(text_emb): Embedding(10256, 1024)
(image_emb): Embedding(8192, 1024)
)
AI助手
