CS231n Spring 2019 Assignment 3—StyleTransfer-PyTorch风格迁移

在上次的Network Visualization之后,这次接着是Style Transfer,选择的还是PyTorch版本,个人认为这次代码还是比较简单有趣,也是直接在StyleTransfer-PyTorch.ipynb编写完成的。这次自己编一编,感觉再去看教程和视频就非常有感觉,也有很多收获(之前看lecture pdf一点感觉都没有)。

按教程看来,风格迁移(style transfer)的思想其实从纹理合成来的(Texture Synthesis)。用简短的话描述风格迁移就是:

拿两张图像,一张为内容图像(content source),一张为风格图像(style source),然后生成一张图像。通过构造一个总的损失函数,减小损失使生成的图像能够匹配内容特征与网络各层的风格特征,对生成图像像素进行梯度下降,减小总损失,就能使生成的图像既具有内容图像的内容,也有风格图像的风格。

教程开头给出的示意图:

风格迁移示意图

所以这里最主要的就是损失函数的构造,幸运的是,作业里面已经给出了损失函数的公式。分别为(注意,这里损失都是feature map之间的分离度的衡量):

  1. Content loss:生成图像与内容源图像某一层feature map之间的内容deviation
  2. Style loss:生成图像与风格源图像某几层feature map之间的风格deviation
  3. Total-variation regularization:全变差正则化,具体见下文

Content loss

就如上面所说,内容损失是用来描述生成图像与内容源图像某一层feature map之间的内容deviation,我们只关注网络某一层(\ell层)的特征图feature map为 A^\ell \in \mathbb{R}^{1 \times C_\ell \times H_\ell \times W_\ell}. C_\ell 是在\ell层通道数, H_\ellW_\ell 是高和宽。 我们将其reshape后计算(其实就内容损失不必要进行reshape,其实就是每个像素之间的差)。假设 F^\ell \in \mathbb{R}^{C_\ell \times M_\ell} 是生成图像的feature map ,P^\ell \in \mathbb{R}^{C_\ell \times M_\ell} 内容源图像的feature map,其中 M_\ell=H_\ell\times W_\ell 。给出内容损失函数如下,其中w_c是内容损失权值:
L_c = w_c \times \sum_{i,j} (F_{ij}^{\ell} - P_{ij}^{\ell})^2
其实就是两张图像像素值的平方差,然后加权求和,代码如下:

def content_loss(content_weight, content_current, content_original):
    """
    Compute the content loss for style transfer.
    
    Inputs:
    - content_weight: Scalar giving the weighting for the content loss.
    - content_current: features of the current image; this is a PyTorch Tensor of shape
      (1, C_l, H_l, W_l).
    - content_target: features of the content image, Tensor with shape (1, C_l, H_l, W_l).
    
    Returns:
    - scalar content loss
    """
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    N_l, C_l, H_l, W_l = content_current.shape
    F = content_current.view(C_l, H_l*W_l)
    P = content_original.view(C_l, H_l*W_l)
    ct_loss = content_weight * torch.sum((F - P)**2)
    return ct_loss

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

里面主要用了一个view函数,其作用和reshape函数一样

Style loss

损失函数比内容损失稍微麻烦一点,不能直接相减,因为这不能很好地衡量两张图像之间的风格偏差,这里用的是一种有效又计算简单的描述符——Gram Matrix.这个Gram描述符能较好的衡量图像的风格相关性,就如协方差(covariance)矩阵,但这个计算更容易高效一些,具体就是对于一个feature map F^\ell 形状为 (C_\ell, M_\ell), Gram matrix的形状是 (C_\ell, C_\ell) ,因为它的计算公式如下:
G_{ij}^\ell = \sum_k F^{\ell}_{ik} F^{\ell}_{jk}
一开始写了个循环的,结果后面生成图像非常慢,后来想到向量化,发现就是F^\ell乘以它的转置就能直接得到整个Gram Matrix,效率就高多了(后来发现lecture 13的57页就有,捂脸)
还有一点与内容损失不一样的地方是风格损失是用到多层的的feature map的Gram Matrix,一层之间的损失乘以权值w_{\ell}为:
L_s^\ell = w_\ell \sum_{i, j} \left(G^\ell_{ij} - A^\ell_{ij}\right)^2其中 G^\ell 是生成图像的feature map来的Gram matrix, A^\ell 是源风格图像的feature map的Gram Matrix 。最终的风格损失是多层损失的求和:
L_s = \sum_{\ell \in \mathcal{L}} L_s^\ell
以下是求Gram Matrix的代码(有关gram matrix的讨论可以看这里),主要有用到维度转换的一个函数torch.permute()和高维tensor矩阵相乘的函数torch.matmul()(注意torch.mm()函数只适用二维矩阵相乘):

def gram_matrix(features, normalize=True):
    """
    Compute the Gram matrix from features.
    
    Inputs:
    - features: PyTorch Tensor of shape (N, C, H, W) giving features for
      a batch of N images.
    - normalize: optional, whether to normalize the Gram matrix
        If True, divide the Gram matrix by the number of neurons (H * W * C)
    
    Returns:
    - gram: PyTorch Tensor of shape (N, C, C) giving the
      (optionally normalized) Gram matrices for the N input images.
    """
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
#     Way 1 inefficient
#     N, C, H, W = features.shape
#     F = features.reshape(N, C, H * W)
#     gram = torch.zeros(N, C, C)
#     for n in range(N):
#         for i in range(C):
#             for j in range(C):
#                 gram[n,i,j] = torch.sum(F[n,i,:] * F[n,j,:])
    
    # Way 2
    N, C, H, W = features.shape
    F = features.view(N, C, -1)
    F_T = F.permute(0, 2, 1)
    gram = F.matmul(F_T)

    if normalize:
        gram /= (H * W * C)
        
    return gram
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

根据上面的公式可以很容易写出风格损失代码,刚开始看输入参数的说明还不太明白,可以写看一下文件中下一个cell的style_loss_test()函数,能够知道各个输入参数代表什么:

# Now put it together in the style_loss function...
def style_loss(feats, style_layers, style_targets, style_weights):
    """
    Computes the style loss at a set of layers.
    
    Inputs:
    - feats: list of the features at every layer of the current image, as produced by
      the extract_features function.
    - style_layers: List of layer indices into feats giving the layers to include in the
      style loss.
    - style_targets: List of the same length as style_layers, where style_targets[i] is
      a PyTorch Tensor giving the Gram matrix of the source style image computed at
      layer style_layers[i].
    - style_weights: List of the same length as style_layers, where style_weights[i]
      is a scalar giving the weight for the style loss at layer style_layers[i].
      
    Returns:
    - style_loss: A PyTorch Tensor holding a scalar giving the style loss.
    """
    # Hint: you can do this with one for loop over the style layers, and should
    # not be very much code (~5 lines). You will need to use your gram_matrix function.
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    sl_loss = 0.0
    for i in range(style_layers.__len__()):
        sl_loss += style_weights[i] * torch.sum((gram_matrix(feats[style_layers[i]].clone()) - style_targets[i])**2)
    return sl_loss

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

Total-variation regularization

除了上面两种损失以外,还可以加入一种正则化手段:全变差正则化(Total-variation regularization),证明能够增加图像的平滑度,具体计算就是图片的相邻的行像素相减,相邻的列像素相减,如图所示(列相减的情况类似):

Total-variation regularization示意图

具体公式就是:

这可以不用循环写:

def tv_loss(img, tv_weight):
    """
    Compute total variation loss.
    
    Inputs:
    - img: PyTorch Variable of shape (1, 3, H, W) holding an input image.
    - tv_weight: Scalar giving the weight w_t to use for the TV loss.
    
    Returns:
    - loss: PyTorch Variable holding a scalar giving the total variation loss
      for img weighted by tv_weight.
    """
    # Your implementation should be vectorized and not require any loops!
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

    tv1 = torch.sum((img[:,:,:,1:] - img[:,:,:,:-1])**2)
    tv2 = torch.sum((img[:,:,1:] - img[:,:,:-1])**2)
    t_v_loss = tv_weight * (tv1 + tv2)
    return t_v_loss

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

之后还有:

  • 特征反演(Feature Inversion):从原始噪声,将风格权值设为0,开始生成与源内容图像相近的一张图
  • 纹理合成(texture synthesis):将内容权值设为0,合成一张包含有源风格图像中纹理特征的图像

这些都只要改一下权值就可以了

结果

200次迭代后的以tubingen为内容源图像,以composition_vii为风格源图像生成的风格迁移图:


风格迁移图1

200次迭代后的以tubingen为内容源图像,以the_scream为风格源图像生成的风格迁移图:


风格迁移图2

200次迭代后的以tubingen为内容源图像,以starry_night为风格源图像生成的风格迁移图:
风格迁移图3

200次迭代后的以tubingen为内容源图像,以starry_night为风格源图像生成的特征反演(Feature Inversion)图:


特征反演(Feature Inversion)图

200次迭代后的以tubingen为内容源图像,以starry_night为风格源图像生成的纹理合成(texture synthesis)图:
纹理合成(texture synthesis)图

链接

前后面的作业博文请见:

写得不好之处,还请多多包涵,可以在下方评论交流~~~

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 196,099评论 5 462
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 82,473评论 2 373
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 143,229评论 0 325
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,570评论 1 267
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,427评论 5 358
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,335评论 1 273
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,737评论 3 386
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,392评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,693评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,730评论 2 312
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,512评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,349评论 3 314
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,750评论 3 299
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,017评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,290评论 1 251
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,706评论 2 342
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,904评论 2 335

推荐阅读更多精彩内容