时间序列预测方法之 WaveNet

本文链接个人站 | 简书 | CSDN
版权声明:除特别声明外,本博客文章均采用 BY-NC-SA 许可协议。转载请注明出处。

最近打算分享一些基于深度学习的时间序列预测方法。这是第四篇。

前面已经分享了两个基于 RNN 的模型(DeepARDeepState)和一个基于 Attention 的模型(Transformer),今次将会介绍一个基于 CNN 的模型。

Google DeepMind 在 2016 年发表了 WaveNet: A generative model for raw audio 。这篇文章介绍了一种用于音频生成的神经网络 WaveNet。该模型在文字转语音(Text-to-Speech, TTS)任务上取得了极佳的效果,与当时业界已有的模型不同,它合成的声音与真人的发音非常接近。目前 WaveNet 已经被应用在 Google Assistant 语音助手中。

WaveNet 是一个自回归概率模型,它将音波 \mathrm x = \{x_1, \cdots, x_T\} 的联合概率分布建模为
p(\mathrm x) = \prod_{t=1}^Tp(x_t|x_1, \cdots, x_{t-1})
这种建模方式与 DeepAR 十分类似,因而可以很自然地迁移到时间序列预测的任务上——说起来音频信号本身也是一种时间序列。Amazon 在其开源的 GluonTS 库中就实现了一个基于 WaveNet 的时间序列预测模型。

接下来我会主要基于 DeepMind 的文章介绍一下 WaveNet 的网络结构,并给出一个 demo。

Model

WaveNet 网络的基本单元是我们在介绍 Transformer 时提到过的因果卷积(causal convolution)。这种卷积方式可以确保模型在每一个时间步都不会接触到未来时间步的信息,如下图所示。

Visualization of a stack of causal convolutional layers.

卷积的缺陷在于其本质上捕捉的是局部信息,需要通过增大卷积核或(和)增加模型层数的方式来扩大感受野。对于较长的序列,上述方法就杯水车薪了。为了解决这个问题,WaveNet 采用了空洞因果卷积(dilated causal convolution)。所谓空洞卷积,就是以一定的步长跳过输入值,将卷积核应用到超过其自身尺寸的区域,从而在层数不多的情况下也能拥有较大的感受野。如下图所示。


Visualization of a stack of dilated causal convolutional layers.

为了加快收敛以及训练更深的模型,WaveNet 使用了 residual connection 和 skip connection。完整的网络结构如下图所示。


Overview of the residual block and the entire architecture.

注意到在残差块内部,使用了 gated activation unit 作为激活函数,即
output = \tanh(W_{f,k}* input) \odot \sigma(W_{g,k}* input)
文章声称在建模音频信号时这种激活函数的表现优于 ReLU。另一个有意思的细节是,WaveNet 在最后使用 Softmax 输出概率分布。Softmax 对分布的形状没有任何假设,文章认为这使得它适合用来拟合任意的分布。当然,这意味着需要将连续的数据离散化,并且训练时需要采用交叉熵作为损失函数。

Amazon 在进行时间序列预测时使用的网络结构与上述的内容并无二致,因而不需要再做过多介绍。预测时也是跟 DeepAR 一样的套路,先用自回归祖先采用获取一批样本,再利用样本计算感兴趣的统计量。这里比较有意思的是,在采样的时候,Amazon 给 Softmax 加了一个所谓的“温度” T,形如
p_i = \frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}
不难发现,“温度”越高,不同取值的概率越接近,得到的分布越平滑。这个做法十有八九是从知识蒸馏中借鉴过来的。

Code

Amazon 的实现是基于 MXNet 的。我们还是延续之前的做法,自己用 TensorFlow 构建一个 demo。

下面给出残差块和网络结构的定义:

import tensorflow as tf

class DilatedCausalResidual(tf.keras.layers.Layer):
    """
    空洞因果卷积残差块
    """
    def __init__(self, dilation_rate, kernel_size, residual_channels, dilation_channels, skip_channels, return_dense):
        super().__init__()
        self.return_dense = return_dense
        
        self.conv_tanh = tf.keras.layers.Conv1D(
            filters=dilation_channels,
            kernel_size=kernel_size,
            strides=1,
            padding='causal',
            dilation_rate=dilation_rate,
            activation='tanh'
        )
        self.conv_sigmoid = tf.keras.layers.Conv1D(
            filters=dilation_channels,
            kernel_size=kernel_size,
            strides=1,
            padding='causal',
            dilation_rate=dilation_rate,
            activation='sigmoid'
        )
        self.conv_skip = tf.keras.layers.Conv1D(
            filters=skip_channels,
            kernel_size=1,
            strides=1,
            padding='causal'
        )
        if self.return_dense:
            self.conv_residual = tf.keras.layers.Conv1D(
                filters=residual_channels,
                kernel_size=1,
                strides=1,
                padding='causal'
            )
        else:
            self.conv_residual = None
    
    
    def call(self, inputs):
        tanh = self.conv_tanh(inputs)
        sigmoid = self.conv_sigmoid(inputs)
        
        z = tf.multiply(tanh, sigmoid)
        skip = self.conv_skip(z)
        
        if not self.return_dense:
            return skip, None
        
        residual = self.conv_residual(z)
        dense = inputs + residual
        
        return skip, dense


class WaveNet(tf.keras.models.Model):
    """
    WaveNet 模型
    """
    def __init__(self, dilation_rates, kernel_size, residual_channels, dilation_channels, skip_channels, logits_channels):
        super().__init__()
        
        self.causal_conv = tf.keras.layers.Conv1D(
            filters=residual_channels,
            kernel_size=kernel_size,
            strides=1,
            padding='causal'
        )
        
        self.residual_stacks = []
        for i, dilation_rate in enumerate(dilation_rates):
            is_not_last = i < len(dilation_rates) - 1
            self.residual_stacks.append(
                DilatedCausalResidual(
                    dilation_rate=dilation_rate,
                    kernel_size=kernel_size,
                    residual_channels=residual_channels,
                    dilation_channels=dilation_channels,
                    skip_channels=skip_channels,
                    return_dense=is_not_last
                )
            )
        
        self.relu = tf.keras.layers.Activation('relu')
        self.conv1 = tf.keras.layers.Conv1D(
            filters=skip_channels,
            kernel_size=1,
            strides=1
        )
        self.conv2 = tf.keras.layers.Conv1D(
            filters=logits_channels,
            kernel_size=1,
            strides=1
        )
        
    def call(self, inputs):
        o = self.causal_conv(inputs)
        
        skip_outs = []
        for block in self.residual_stacks:
            skip, o = block(o)
            skip_outs.append(skip)
        
        total = sum(skip_outs)
        outputs = self.relu(total)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        
        # 这里直接输出 logits
        return outputs

训练过程与之前介绍的模型大同小异,这里不再赘述了。

为了验证代码,我们随机生成一个带有周期的时间序列。下图展示了这个序列的一部分数据点。


时间序列

简单起见,我们没有加入额外的特征。

经过训练后用于预测,效果如下图所示,其中阴影部分表示 0.05 分位数 ~ 0.95 分位数的区间。


预测效果

与其它模型对比

  • WaveNet 与之前介绍的 DeepAR 和 Transformer 类似,都是自回归模型。
  • DeepAR 基于 RNN,在训练的时候无法并行。Transformer 基于 Attention,WaveNet 基于 CNN,它们训练时都是并行的。
  • 自回归模型在预测时都无法并行。但 DeepAR 的 RNN 结构保留了网络状态,采样时的计算量最小。Transformer 每采样一个时间步都需要进行一次全局 Self Attention,计算量最大。从 Amazon 给出的实现中可以看到,WaveNet 在经过适当改造之后,能够利用缓存减少不必要的重复计算,从而加速采样。
  • DeepAR 和 Transformer 学习的是概率分布的参数,WaveNet 学习的是 Softmax 的 logits。由于前两者假定了分布的形状,因而预测的结果看上去较为规则。将 WaveNet 的学习目标改为与前两者一样似乎也未尝不可,不过效果如何需要通过实验来验证。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
禁止转载,如需转载请通过简信或评论联系作者。
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 210,978评论 6 490
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 89,954评论 2 384
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 156,623评论 0 345
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,324评论 1 282
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,390评论 5 384
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,741评论 1 289
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,892评论 3 405
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,655评论 0 266
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,104评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,451评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,569评论 1 340
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,254评论 4 328
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,834评论 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,725评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,950评论 1 264
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,260评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,446评论 2 348