【论文笔记】A Neural Representation of Sketch Drawings
谷歌的研究成果采用了结合了seq2seq模型与VAE技术的方法,并用于生成手绘图像序列。
https://arxiv.org/pdf/1704.03477.pdf
本文主要是一篇论文摘要的翻译整理。
文章目录
-
- 1.Introduction
-
相关工作综述
-
方法
-
3.1 数据集介绍
-
3.2 基于草图的递归神经网络
-
无条件生成模型研究
-
训练过程优化策略
-
4.Experiments
-
- 基于条件的重建技术
-
- 潜在空间中的插值方法
-
- 图像素描类比(类比)
-
- 预测不完整素描的不同结局
-
5.Applications and Future Work;6.结论;略
-
附录
-
1.Introduction
- 生成模型的发展主要体现在GANs(生成对抗网络)、VI(变分推断)以及AR(自回归)等技术上。
- 当前这类方法主要应用于基于像素的空间数据(pixel images),而人类的认知过程则涉及向量序列的理解。
- 文本贡献方面,则包括以下几项:一是针对条件和非条件两种情况下的序列生成框架进行了研究;二是提出了名为"sketch-RNN"的新架构;三是通过引入潜在空间映射方法使训练更加稳定;四是进一步探讨了该方法在艺术风格迁移和图像恢复中的应用前景。
2.Related Work
- 在模仿绘画领域中,主要采用基于既定规则文件执行绘画的机器人,并非完全依赖生成模型;早期研究主要针对线条的使用多为HMM(隐马尔可夫模型)方法;近年来则发展出基于RNN(长短时记忆网络)以及混合密度网络等方法来处理连续数据序列,并扩展应用于汉字书写分析。
- 最新研究将Sequence-to-Sequence模型与VAE(变分自编码器)相结合以实现英语语言编码至潜在空间的研究。
- 最后部分介绍了若干现有的公开数据集。
3.方法
3.1 数据集
- 源自QuickDraw应用,在短时间内生成多种草图,并涵盖百多种类别;每个类别提供约7万份高质量训练数据,并包含约2千份验证样本和测试样本。
- 数据序列采用[dx, dy, p1, p2, p3]的形式组织,在此框架下:
- 其中dx和dy分别代表x和y方向的变化量;
- 其中p1指示继续绘制当前路径线段的状态,
- p2标识结束当前子序列的状态,
- 而p3则标志绘图过程的完成。

3.2 Sketch-RNN

- 主要架构基于序列到序列变分自编码器。
- 其中编码器采用双向RNN架构,并将草图作为输入数据。
潜在空间向量则被用作输出表示。
(通常采用常规变分自编码器(VAE)模型:首先计算均值\mu和标准差\sigma;然后通过正态分布采样得到隐变量向量z。)
- 其中编码器采用双向RNN架构,并将草图作为输入数据。
h_{{\Rightarrow}}={\rm encode}_{{⇒}}(S), h_{{←}}={\\rm encode}_{{←}}({S_{{rev}}} ), h=[h_ {{⇒ }}; h_ {{← }}]
\mu=W_ {{μ }}h+b _ {{μ }}, \\hat{\\ sigma }=W _ {{σ }}h+b _ {{σ }}, \\ sigma ={\\rm exp}\lgroup\\frac{\\ hat {\\ sigma }}{2}\\rgroup, z=μ+({\\boldsymbol}{σ}\odot {\\mathcal{N}(0,1)})
解码器采用自回归RNN架构,在生成过程中利用前一时刻的输出作为当前时刻的输入。基于此,在编码阶段已采用双向RNN模型,在解码阶段z经过tanh函数获得初始状态向量。数学公式表示为:[h_0;c_0]=\tanh(W_z z + b_z)
S_0定义为(0,0,1,0,0)
(dx, dy)基于M元正态分布的高斯混合模型(GMM)估计概率值,并将(q1, q2, q3)作为分类指标参与计算;同时M也构成一个类别分布,并在GMM中体现为各组分的混合比例。相应的概率密度函数表达式为:
p(\triangle x,\triangle y)=\sum_{j=1}^{M} \mathcal{N}(\triangle x,\triangle y | \mu_{x,j},\mu_{y,j},\sigma_{x,j},\sigma_{y,j},\rho_{xy,j}),
其中满足约束条件\sum_{j=1}^{M}\Pi_j=1
由此可见,在解码器模块中,
输出空间维度由输入特征图尺寸决定,
总共有5M + M + 3个神经元,
即总共有6M + 3维的空间维度。
x_i = [S_{i-1}; z], [h_i; c_i] = forward(x_i, [h_{i-1}; c_{i-1}]), y_i = W_y h_i + b_y, y_i ∈ ℝ^{6M+3}
其中,
(\widehat{Π}μ_xμ_y\widehat{σ}_x\widehat{σ}_y\rho_{xy})_1,…,(\widehat{Π}μ_xμ_y\widewidecheck{σ}_x\widewidecheck{σ}_y\rho_{xy})_M,\widecheck{q₁},\widecheck{q₂},\widecheck{q₃}与yᵢ对应。
为了确保标准差数值的非负性,在计算过程中我们采用exp函数将结果限制在正数范围内,并利用tanh函数将相关系数限定在-1至1之间。
各类别概率分布由以下公式给出:
q_k = \frac{\exp(\widehat{q}_k)}{\sum_{j=1}^{3} \exp(\widehat{q}_j)}, \quad k \in \{1, 2, 3\}
以及
\Pi_k = \frac{\exp(\widehat{\Pi}_k)}{\sum_{j=1}^{M} \exp(\widehat{\Pi}_j)}, \quad k \in \{1, ..., M\}
本研究存在p₁、p₂、p₃状态数据分布失衡的问题,在现有解决方案中普遍采用样本加权技术这一做法并不适用于多分类场景。为此我们提出了一种基于设定最大序列长度的新方案,在实际运行后将最终结果统一标记为(0,0,0,0,1)的形式
在训练过程中,在每一步骤中计算当前时间段的结果。而生成过程则通过将当前时间段的结果作为下一个时间段的输入来进行操作,在这一过程中会持续进行直到输出值p3达到1或最长序列长度时停止。
引入了温度参数\tau以增强序列的随机性,在取值范围为0到1的情况下,当趋近于0时,模型预测结果越具有确定性。
3.3Unconditional Generation
通过仅训练解码器而不使用编码器、输入数据以及潜在空间向量,并设定初始隐藏层的状态为零向量,则能够构建出一个纯生成模型
3.4 Training
该模型基于变分自编码器的方法,在其中的损失函数中包含了重建损失项L_R和KL散度_loss_项L_{KL}。
对于重建损失函数来说,则分别由位移量(\triangle x,\triangle y)对应的对数似然损失L_s以及笔触状态(p_1,p_2,p_3)对应的对数似然损失L_p构成。(注:其中N_s表示序列的实际长度)
具体而言,
L_s = -\frac{1}{N_\text{\max}} \sum\limits _{i=1}^{N_s}\log \left( \sum _{j=1 }^M \Pi _{j ,i } \mathcal { N } (\triangle x_i , \triangle y_i | μ _{x,j ,i }, μ _{y,j ,i }, σ _{x,j ,i }, σ _{y,j ,i }, ρ _{\text{x,y}, j ,i })
而
L_p = -\frac{ 1 } { N_\text{\max}} ∑ ∑ p{k, i }\log q{k, i }, L_R = L_s + L_p
KL散度损失用于衡量潜在空间中向量z与独立同分布的高斯向量之间的差异。(通过将不同草图置于潜在空间中较近的位置,并有助于插值过程具有意义))
L_{KL}=-\frac{1}{2N_z} (1+\hat{\sigma}-\mu^2-\exp(\hat{\sigma}) )\\ Loss=L_R+w_{KL}L_{KL}
4.Experiments
- 采用多种分类场景与单一分类策略,并通过不同的KL散度权重进行实验验证。
- 编码器基于双向LSTM结构,解码器基于HyperLSTM结构。
4.1 Conditional Reconstruction
分别对猫和猪的数据进行单独训练,并设置不同的τ值范围,在实验中发现当τ值较低时则重建结果趋于稳定;此外,在模型经过充分训练后仍能较好地维持数据特征。值得注意的是,在模型完成训练任务后仍具备良好的鲁棒性特点,在面对异常输入如一个牙刷时同样能够有效提取关键属性并完成相应的识别任务。

4.2 Latent Space Interpolation
4.3 Sketch Drawing Analogies(类比)
- 利用潜在空间中的数据插值技术来呈现草图的变化过程,并且优化KL散度参数至更高水平时,则能够形成更为清晰的数据流形结构;

4.4 Predicting Different Endings of Incomplete Sketches
- 一个应用点,根据初始的几笔,来补充完整的草图
5.Applications and Future Work;6.结论;略
附录
在数据预处理阶段, 对偏移量(\triangle x,\triangle y)进行归一化处理至方差为1的大小范围. 无需进行零均值归一化操作(因原始数据的均值较小).
在计算KL散度损失时,引入退火算法,效果更好
模型架构设置中采用M=20的混合配置方案;采用 Adam优化算法(学习率为 1 \times 1e^{-4}),并设定 recurrent dropout 保留率为 90%;解码器层包含 2k 神经单元;采用 batch_size=64 的训练批量大小;梯度裁剪阈值设为 1;最小 KL 散度设为 3倍于初始值(调节参数 R 设定为 R = \gamma^{step})。
The number of points cannot exceed 300. In this study, the Douglas-P克法 algorithm was employed to reduce the number of data points to below 200.
对于复杂的图像,重建效果较差,且更倾向于圆滑的效果
类别数不宜过多
其他略
