Advertisement

【Diffusion 视频生成】Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation

阅读量:

这篇文章介绍了Tune-A-Video方法,该方法使用扩散模型进行文本到视频的生成,特别是针对One-Shot视频生成任务。论文提出了一种稀疏时空注意机制,仅访问第一个和前一个视频帧,并且只更新注意力块中的投影矩阵,以提高时间一致性。此外,作者还提出了一种改进后的反向过程,通过潜变量作为初始噪声,使生成的视频更加连贯。论文还提供了对应的代码实现,并在GitHub上公开了Tune-A-Video的GitHub仓库。该方法在扩散模型视频生成领域具有重要意义,展示了如何利用预训练的文本到图像模型进行高效视频生成。

[Diffusion Models系列文章:基础与应用]( "Diffusion Models系列文章:基础与应用)

引言: 在引言部分,Tune-A-Video提出了一种新的文本-视频任务,即One-Shot视频生成,其在视频对象编辑、背景编辑、风格转换以及可控生成等方面均展现了显著的效果。本文深入解析了Tune-A-Video的论文及其代码实现,旨在为从事扩散模型视频生成的读者提供有益的参考。

目录

贡献概述

方法详解

稀疏时空注意机制

反向过程的结构引导

论文和代码

代码解读

Transformer3DModel

注意力机制 BasicTransformerBlock

改进后的Unet结构UNet3DConditionModel

个人感悟


贡献概述

Tune-A-Video开发出了一个创新的文本-视频任务模块:One-Shot视频生成功能,该功能仅基于单一的文本-视频样本对进行训练以生成相应的文本-视频内容。

博主认为这篇论文的两个主要创新点:

  1. 提出了一种稀疏的时空注意机制:该机制仅关注初始帧和前一帧,同时仅更新注意力块中的投影矩阵。
  2. 在推理过程中,通过将潜变量作为初始噪声输入,能够生成更加连贯的视频内容。

作者概括本篇论文的四个主要贡献点:

一种单次视频调整的T2V生成方法被提出,旨在缓解使用大规模视频数据集进行训练所带来的负担。该方法首次采用预训练的T2I模型进行T2V生成,开创了这一领域的发展。作者开发出高效率的注意力调整和结构反转方法,明显提高了时间一致性。通过实证研究展示了我们方法的明显结果。

方法详解

稀疏时空注意机制

Tune-A-Video采用了伪三维卷积层,与前一研究工作相似。此外,为更好地保持帧间的空间一致性并降低计算开销,Tune-A-Video设计了一种稀疏因果注意力机制。具体而言,该机制通过将当前帧的查询矩阵与序列首帧及前一帧的键矩阵计算关联性,并从首帧及前一帧提取值,从而实现了自回归式长视频序列生成。

新的attention计算方法用公式表示为:

如图所示,在模型训练过程中,仅使用了ST-attention和Cross-attention。

W^Q

、T-attention的

W^Q
W^K
W^V

是可训练的。

反向过程的结构引导

在进行DDIM反向过程中,不依赖于文本条件,提取源视频V的潜在噪声。将该噪声作为DDIM采样过程的起点。

原论文的描述过于简洁,使得我对这一做法的理解不够深入。作者详细阐述了该方法如何有效解决循环中的视频呈现停滞状态的问题。

论文和代码

Tune-A-Video

GitHub - showlab/Tune-A-Video: ICCV 2023 Single-Click Tuning of Image Diffusion Models for Text-to-Video Generation (T2V)

代码解读

Transformer3DModel

该模型通过注意力机制处理了时间序列数据,实现了对序列信息的高效捕捉。该模型基于自回归架构,能够逐步生成目标序列,同时通过多头注意力机制捕获序列间的复杂关联。该模型利用位置编码和时序信息辅助注意力计算,从而实现了对长序列数据的有效建模。该模型通过自注意力机制实现了序列内信息的全局捕捉,同时通过时序自回归机制实现了序列间的有序建模。该模型基于预训练语言模型的自监督学习框架,通过最大化上下文之间的相关性,实现了对序列数据的深度理解。该模型通过自注意力机制实现了对序列数据的多尺度建模,同时通过时序建模技术提升了对动态变化的捕捉能力。该模型基于Transformer架构,通过并行计算实现了对序列数据的高效处理,同时通过多头注意力机制提升了模型的表达能力。该模型通过自注意力机制实现了对序列数据的全局建模,同时通过时序建模技术提升了对序列生成过程的控制能力。该模型基于自回归模型的自监督学习框架,通过最大化上下文之间的相关性,实现了对序列数据的深度学习。该模型通过自注意力机制实现了对序列数据的多尺度建模,同时通过时序建模技术提升了对序列生成过程的控制能力。

在视频处理中,视频中的注意力结构需要根据视频长度定制不同的隐藏状态,这是一项关键的技术要求。

复制代码
     hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

    
     encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

以BasicTransformerBlock为核心的结构,在其之后,添加了标准的input layer、norm和output layer。

复制代码
 class Transformer3DModel(ModelMixin, ConfigMixin):

    
     @register_to_config
    
     def __init__(
    
     self,
    
     num_attention_heads: int = 16,
    
     attention_head_dim: int = 88,
    
     in_channels: Optional[int] = None,
    
     num_layers: int = 1,
    
     dropout: float = 0.0,
    
     norm_num_groups: int = 32,
    
     cross_attention_dim: Optional[int] = None,
    
     attention_bias: bool = False,
    
     activation_fn: str = "geglu",
    
     num_embeds_ada_norm: Optional[int] = None,
    
     use_linear_projection: bool = False,
    
     only_cross_attention: bool = False,
    
     upcast_attention: bool = False,
    
     ):
    
     super().__init__()
    
     self.use_linear_projection = use_linear_projection
    
     self.num_attention_heads = num_attention_heads
    
     self.attention_head_dim = attention_head_dim
    
     inner_dim = num_attention_heads * attention_head_dim
    
  
    
     # Define input layers
    
     self.in_channels = in_channels
    
  
    
     self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    
     if use_linear_projection:
    
         self.proj_in = nn.Linear(in_channels, inner_dim)
    
     else:
    
         self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
    
  
    
     # Define transformers blocks
    
     self.transformer_blocks = nn.ModuleList(
    
         [
    
             BasicTransformerBlock(
    
                 inner_dim,
    
                 num_attention_heads,
    
                 attention_head_dim,
    
                 dropout=dropout,
    
                 cross_attention_dim=cross_attention_dim,
    
                 activation_fn=activation_fn,
    
                 num_embeds_ada_norm=num_embeds_ada_norm,
    
                 attention_bias=attention_bias,
    
                 only_cross_attention=only_cross_attention,
    
                 upcast_attention=upcast_attention,
    
             )
    
             for d in range(num_layers)
    
         ]
    
     )
    
  
    
     # 4. Define output layers
    
     if use_linear_projection:
    
         self.proj_out = nn.Linear(in_channels, inner_dim)
    
     else:
    
         self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
    
  
    
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
    
     # Input
    
     assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
    
     video_length = hidden_states.shape[2]
    
     hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
    
     encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
    
  
    
     batch, channel, height, weight = hidden_states.shape
    
     residual = hidden_states
    
  
    
     hidden_states = self.norm(hidden_states)
    
     if not self.use_linear_projection:
    
         hidden_states = self.proj_in(hidden_states)
    
         inner_dim = hidden_states.shape[1]
    
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
    
     else:
    
         inner_dim = hidden_states.shape[1]
    
         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
    
         hidden_states = self.proj_in(hidden_states)
    
  
    
     # Blocks
    
     for block in self.transformer_blocks:
    
         hidden_states = block(
    
             hidden_states,
    
             encoder_hidden_states=encoder_hidden_states,
    
             timestep=timestep,
    
             video_length=video_length
    
         )
    
  
    
     # Output
    
     if not self.use_linear_projection:
    
         hidden_states = (
    
             hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
    
         )
    
         hidden_states = self.proj_out(hidden_states)
    
     else:
    
         hidden_states = self.proj_out(hidden_states)
    
         hidden_states = (
    
             hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
    
         )
    
  
    
     output = hidden_states + residual
    
  
    
     output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
    
     if not return_dict:
    
         return (output,)
    
  
    
     return Transformer3DModelOutput(sample=output)

注意力机制 BasicTransformerBlock

查看原文链接:https://github.com/showlab/Tune-A-Video/blob/5803c255ba25538a63eb2b9fcab610550e03f66c/tuneavideo/models/attention.py#L139

该系统的主干部分主要由下图红框框出的部分构成,具体来说,包括ST-Atten、Cross-Atten以及T-Atten三个模块。

建议与diffusers中的BasicTransformerBlock进行对比分析,其结构由自注意力机制、交叉注意力机制以及前馈神经网络组成。

在diffusers项目中,第60行的代码如下:

在diffusers项目中,第60行的代码如下:

ST-Atten继承于CrossAttention,但是需要针对video做一些shape变换。

复制代码
 class SparseCausalAttention(CrossAttention):

    
     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
    
     batch_size, sequence_length, _ = hidden_states.shape
    
  
    
     encoder_hidden_states = encoder_hidden_states
    
  
    
     if self.group_norm is not None:
    
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
    
  
    
     query = self.to_q(hidden_states)
    
     dim = query.shape[-1]
    
     query = self.reshape_heads_to_batch_dim(query)
    
  
    
     if self.added_kv_proj_dim is not None:
    
         raise NotImplementedError
    
  
    
     encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    
     key = self.to_k(encoder_hidden_states)
    
     value = self.to_v(encoder_hidden_states)
    
  
    
     former_frame_index = torch.arange(video_length) - 1
    
     former_frame_index[0] = 0
    
  
    
     key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
    
     key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
    
     key = rearrange(key, "b f d c -> (b f) d c")
    
  
    
     value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
    
     value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
    
     value = rearrange(value, "b f d c -> (b f) d c")
    
  
    
     key = self.reshape_heads_to_batch_dim(key)
    
     value = self.reshape_heads_to_batch_dim(value)
    
  
    
     if attention_mask is not None:
    
         if attention_mask.shape[-1] != query.shape[1]:
    
             target_length = query.shape[1]
    
             attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
    
             attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
    
  
    
     # attention, what we cannot get enough of
    
     if self._use_memory_efficient_attention_xformers:
    
         hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
    
         # Some versions of xformers return output in fp32, cast it back to the dtype of the input
    
         hidden_states = hidden_states.to(query.dtype)
    
     else:
    
         if self._slice_size is None or query.shape[0] // self._slice_size == 1:
    
             hidden_states = self._attention(query, key, value, attention_mask)
    
         else:
    
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
    
  
    
     # linear proj
    
     hidden_states = self.to_out[0](hidden_states)
    
  
    
     # dropout
    
     hidden_states = self.to_out[1](hidden_states)
    
     return hidden_states

Cross-Atten和T-Atten都是CrossAttention:

复制代码
     # Cross-Attn

    
     if cross_attention_dim is not None:
    
         self.attn2 = CrossAttention(
    
             query_dim=dim,
    
             cross_attention_dim=cross_attention_dim,
    
             heads=num_attention_heads,
    
             dim_head=attention_head_dim,
    
             dropout=dropout,
    
             bias=attention_bias,
    
             upcast_attention=upcast_attention,
    
         )
复制代码
     # Temp-Attn

    
     self.attn_temp = CrossAttention(
    
         query_dim=dim,
    
         heads=num_attention_heads,
    
         dim_head=attention_head_dim,
    
         dropout=dropout,
    
         bias=attention_bias,
    
         upcast_attention=upcast_attention,
    
     )

改进后的Unet结构UNet3DConditionModel

该GitHub存储位置提供了一个基于UNet架构的视频调谐模型实现,具体代码位于tuneavideo/models/unet.py文件的第37行。

主要是需要把self.conv1和self.conv2从Conv2d换成InflatedConv3d

复制代码
    self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
复制代码
    self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
复制代码
   if self.use_in_shortcut:

    
         self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

整体代码如下:

复制代码
 class ResnetBlock3D(nn.Module):

    
     def __init__(
    
     self,
    
     *,
    
     in_channels,
    
     out_channels=None,
    
     conv_shortcut=False,
    
     dropout=0.0,
    
     temb_channels=512,
    
     groups=32,
    
     groups_out=None,
    
     pre_norm=True,
    
     eps=1e-6,
    
     non_linearity="swish",
    
     time_embedding_norm="default",
    
     output_scale_factor=1.0,
    
     use_in_shortcut=None,
    
     ):
    
     super().__init__()
    
     self.pre_norm = pre_norm
    
     self.pre_norm = True
    
     self.in_channels = in_channels
    
     out_channels = in_channels if out_channels is None else out_channels
    
     self.out_channels = out_channels
    
     self.use_conv_shortcut = conv_shortcut
    
     self.time_embedding_norm = time_embedding_norm
    
     self.output_scale_factor = output_scale_factor
    
  
    
     if groups_out is None:
    
         groups_out = groups
    
  
    
     self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
    
  
    
     self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    
  
    
     if temb_channels is not None:
    
         if self.time_embedding_norm == "default":
    
             time_emb_proj_out_channels = out_channels
    
         elif self.time_embedding_norm == "scale_shift":
    
             time_emb_proj_out_channels = out_channels 
    
         else:
    
             raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
    
  
    
         self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
    
     else:
    
         self.time_emb_proj = None
    
  
    
     self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
    
     self.dropout = torch.nn.Dropout(dropout)
    
     self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
    
  
    
     if non_linearity == "swish":
    
         self.nonlinearity = lambda x: F.silu(x)
    
     elif non_linearity == "mish":
    
         self.nonlinearity = Mish()
    
     elif non_linearity == "silu":
    
         self.nonlinearity = nn.SiLU()
    
  
    
     self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
    
  
    
     self.conv_shortcut = None
    
     if self.use_in_shortcut:
    
         self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    
  
    
     def forward(self, input_tensor, temb):
    
     hidden_states = input_tensor
    
  
    
     hidden_states = self.norm1(hidden_states)
    
     hidden_states = self.nonlinearity(hidden_states)
    
  
    
     hidden_states = self.conv1(hidden_states)
    
  
    
     if temb is not None:
    
         temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
    
  
    
     if temb is not None and self.time_embedding_norm == "default":
    
         hidden_states = hidden_states + temb
    
  
    
     hidden_states = self.norm2(hidden_states)
    
  
    
     if temb is not None and self.time_embedding_norm == "scale_shift":
    
         scale, shift = torch.chunk(temb, 2, dim=1)
    
         hidden_states = hidden_states * (1 + scale) + shift
    
  
    
     hidden_states = self.nonlinearity(hidden_states)
    
  
    
     hidden_states = self.dropout(hidden_states)
    
     hidden_states = self.conv2(hidden_states)
    
  
    
     if self.conv_shortcut is not None:
    
         input_tensor = self.conv_shortcut(input_tensor)
    
  
    
     output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
    
  
    
     return output_tensor

个人感悟

该方法的局限性在于生成的动作仅限于与原始训练样本一致,然而,这本来就是为特定场景设计的方案,因此,这种限制是可以被接受的。

视频领域中的扩散模型数量稀少,这篇论文非常值得推荐。

3、代码是基于diffusers开发的,对开发者来说很友好!

4、反向过程的结构引导这部分没看明白,下次再研究研究

全部评论 (0)

还没有任何评论哟~