《AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks》是CVPR 2018 文本生成图像的文章,是StackGAN++的后续工作。
Abstract
在本文中作者提出了一个 Attentional Generative Ad-
versarial Network(AttnGAN),一种attention-driven的多stage的细粒度文本到图像生成器。
并借助一个深层注意多模态相似模型(deep attentional multimodal similarity model)来训练该生成器。
它首次表明 the layered attentional GAN 能够自动选择单词级别的condition来生成图像的不同部分。
大体思路
先看下模型结构:
该模型由两部分组成
-
attentional generative network
该部分使用了注意力机制来生成图像中的子区域,并且在生成每个子区域时还考虑了文本中与该子区域最相关的词。如下图所示:
- Deep Attentional Multimodal Similarity Model (DAMSM)
该部分用来计算生成的图像与文本的匹配程度。用来训练生成器。
Pipeline:
- 输入的文本通过一个Text Encoder 得到 sentence feature 和word features
- 用sentence feature 生成一个低分辨率的图像
- 基于
加入 word features 和setence feature 生成更高分辨率细粒度的图像
Attntional Generative Adversarial Network
下面将分别介绍之前提到的两个模块。
Attentional Generative Network
从Figure 2 中可以看出该attentional generative network共有m个生成器(,
,...,
)它们的输入分别是(
,
,...,
)生成(
,
,...
)
其中
z是noise,服从标准正态分布
是global sentence vector
e是word vector的矩阵
是将
转换为conditioning vertor的Conditioning Augmentation 方法
是第i个stage的attention model
有两个输入,e和h
h的每一列就是图像中一个子区域的feature vector
其中
生成器的目标函数定义如下:
第i个stage的生成器
unconditional loss决定了img是real或者fake,conditional loss 决定img和sentence是否匹配。
同理,判别器的loss定义为
Deep Attentional Multimodal Similarity Model
DAMSM学习了两个神经网络(text encoder-LSTM,image encoder -CNN),将图像的子区域和句子中的词映射到同一个语义空间来计算相似度,在训练生成器的时候就可以通过计算img-text similarity得到一个fine-grained loss
The text encoder
文本编码器是一个双向LSTM网络,用来提取文本描述(text description)的语义向量(semantic vectors)。
作者将该双向LSTM网络的hidden states 认为是词的语义(semantic meaning).
所有词的特征矩阵为e 纬度为DxT,D是 word vector的纬度,T是词的数量。第i列是第i个词的feature vector
并且,最后一次hidden states被认为是 global sentence vector 同样是D纬的。
The image encoder
图像编码器是一个将图像映射到语义空间的CNN网络。中间层的输出被认为是图像不同子区域的局部特征,后面层的输出是图像的全局特征。
作者从Inception-v3的mixed_6e层得到local feature f(768x289),f的每一列是一个子区域的feature vector,768是feature vector的纬度,289是子区域的个数。
并且,global feature 的纬度是2048
然后将img feature映射到text feature同一个空间,得到v和
The attention-driven image-text mathching score
该score是用来衡量img-sentence pair的匹配程度。
先计算句子中所有可能的单词对和图像中的子区域的相似度矩阵s。
s=v
是句子的第i个单词和图像的第j个子区域之间的点积相似度。
用如下方法归一化s
然后,作者构建了一个注意模型来计算每个词的region-context vector。
作者使用
最后整张图像(Q)和描述(D)的match score被定义为
DAMSM Loss
对一个batch的img-sentence 对,
与
匹配的后验概率被为
loss使用负对数后验概率(negative log posterior probability)
w for word
若将公式10重新新定义为
R(Q,D)=(
于是DAMSM loss 为:
结果
论文地址here
代码地址here
最后
该文章是很早之前就看过了的,也跑了他们的代码,感觉DAMSM部分还是有很多地方没吃透,结合代码看任然还是有很多不懂的地方。这次写出来有点新的体会,但是在DAMSM将img和text映射到同一个空间过后的的操作和多stage生成器部分还是有点模糊,可能还需要看看前作stack++。
之后有时间可能还会结合代码再写一篇blog解析一下不懂的地方。