Feature extraction from text with a bi-LSTM

This section is explained mainly through the comments in the code.

1. The core TextRNN code is given below; the part to focus on is the LSTM class.

import torch.nn as nn
import torch

class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()

        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features)
        self.init_params()

    def init_params(self):
        nn.init.kaiming_normal_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        x = self.linear(x)
        return x


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, bidirectional, dropout):
        """
        Args:
            input_size: x 的特征维度
            hidden_size: 隐层的特征维度
            num_layers: LSTM 层数
        """
        super(LSTM, self).__init__()

        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout
        )

        self.init_params()

    def init_params(self):
        for i in range(self.rnn.num_layers):
            nn.init.orthogonal_(getattr(self.rnn, 'weight_hh_l{}'.format(i)))
            nn.init.kaiming_normal_(getattr(self.rnn, 'weight_ih_l{}'.format(i)))
            nn.init.constant_(getattr(self.rnn, 'bias_hh_l{}'.format(i)), val=0)
            nn.init.constant_(getattr(self.rnn, 'bias_ih_l{}'.format(i)), val=0)
            # set the forget-gate bias (the second of the four gate chunks) to 1;
            # the in-place fill_ is wrapped in no_grad so autograd does not complain
            with torch.no_grad():
                getattr(self.rnn, 'bias_hh_l{}'.format(i)).chunk(4)[1].fill_(1)

            if self.rnn.bidirectional:
                nn.init.orthogonal_(
                    getattr(self.rnn, 'weight_hh_l{}_reverse'.format(i)))
                nn.init.kaiming_normal_(
                    getattr(self.rnn, 'weight_ih_l{}_reverse'.format(i)))
                nn.init.constant_(
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)), val=0)
                nn.init.constant_(
                    getattr(self.rnn, 'bias_ih_l{}_reverse'.format(i)), val=0)
                with torch.no_grad():
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)).chunk(4)[1].fill_(1)

    def forward(self, x, lengths):
        '''
        See the end of this post for a short example of how pack_padded_sequence
        and pad_packed_sequence work.
        '''
        # x: [seq_len, batch_size, input_size]
        # lengths: [batch_size]
        # recent PyTorch versions expect the lengths tensor to live on the CPU
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu())

        # packed_x, packed_output: PackedSequence objects
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        # cell:   [num_layers * num_directions, batch_size, hidden_size]
        # Note: hidden holds the final hidden state of every layer and direction,
        #       while output below holds the hidden state at every time step.
        #       cell is the cell state; between adjacent time steps the cell state
        #       usually changes little, whereas the hidden state can change a lot.
        packed_output, (hidden, cell) = self.rnn(packed_x)

        # output: [max_seq_len, batch_size, hidden_size * num_directions]
        # output_lengths: [batch_size]
        # output is what a subsequent fully connected layer would consume
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)

        return hidden, output


class TextRNN(nn.Module):

    def __init__(self, embedding_dim, output_dim, hidden_size, num_layers, bidirectional, dropout,
                 pretrained_embeddings):
        super(TextRNN, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(
            pretrained_embeddings, freeze=False)
        self.rnn = LSTM(embedding_dim, hidden_size, num_layers, bidirectional, dropout)

        self.fc = Linear(hidden_size * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        text, text_lengths = x
        # text: [sent len, batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded: [sent len, batch size, emb dim]

        hidden, outputs = self.rnn(embedded, text_lengths)

        hidden = self.dropout(
            torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))  # concatenate the forward and backward final states of the last layer

        return self.fc(hidden)
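
The slicing hidden[-2, :, :] and hidden[-1, :, :] above relies on how nn.LSTM lays out its final hidden state: hidden has shape [num_layers * num_directions, batch_size, hidden_size], ordered layer by layer with the forward direction before the backward one, so the last two slices are the forward and backward final states of the top layer. Below is a minimal, standalone sketch (not part of the model above) that just checks these shapes:

import torch
import torch.nn as nn

# a 2-layer bidirectional LSTM with the same sizes used later in the demo
rnn = nn.LSTM(input_size=300, hidden_size=200, num_layers=2, bidirectional=True)
x = torch.randn(7, 6, 300)               # [seq_len, batch_size, input_size]
output, (hidden, cell) = rnn(x)

print(output.shape)  # torch.Size([7, 6, 400]) -> hidden_size * num_directions per time step
print(hidden.shape)  # torch.Size([4, 6, 200]) -> num_layers * num_directions
# hidden[-2] / hidden[-1] are the forward / backward final states of the last layer,
# which is exactly what TextRNN concatenates before the fully connected layer.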

2. Full demo

import torch.nn as nn
import torch

from torchtext import data
from torchtext import vocab
from tqdm import tqdm


class Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super(Linear, self).__init__()

        self.linear = nn.Linear(in_features=in_features,
                                out_features=out_features)
        self.init_params()

    def init_params(self):
        nn.init.kaiming_normal_(self.linear.weight)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x):
        x = self.linear(x)
        return x


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, bidirectional, dropout):
        """
        Args:
            input_size: x 的特征维度
            hidden_size: 隐层的特征维度
            num_layers: LSTM 层数
        """
        super(LSTM, self).__init__()

        self.rnn = nn.LSTM(
            input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional, dropout=dropout
        )

        self.init_params()

    def init_params(self):
        for i in range(self.rnn.num_layers):
            nn.init.orthogonal_(getattr(self.rnn, 'weight_hh_l{}'.format(i)))
            nn.init.kaiming_normal_(getattr(self.rnn, 'weight_ih_l{}'.format(i)))
            nn.init.constant_(getattr(self.rnn, 'bias_hh_l{}'.format(i)), val=0)
            nn.init.constant_(getattr(self.rnn, 'bias_ih_l{}'.format(i)), val=0)
            # set the forget-gate bias (the second of the four gate chunks) to 1;
            # the in-place fill_ is wrapped in no_grad so autograd does not complain
            with torch.no_grad():
                getattr(self.rnn, 'bias_hh_l{}'.format(i)).chunk(4)[1].fill_(1)

            if self.rnn.bidirectional:
                nn.init.orthogonal_(
                    getattr(self.rnn, 'weight_hh_l{}_reverse'.format(i)))
                nn.init.kaiming_normal_(
                    getattr(self.rnn, 'weight_ih_l{}_reverse'.format(i)))
                nn.init.constant_(
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)), val=0)
                nn.init.constant_(
                    getattr(self.rnn, 'bias_ih_l{}_reverse'.format(i)), val=0)
                with torch.no_grad():
                    getattr(self.rnn, 'bias_hh_l{}_reverse'.format(i)).chunk(4)[1].fill_(1)

    def forward(self, x, lengths):
        # x: [seq_len, batch_size, input_size]
        # lengths: [batch_size]
        # recent PyTorch versions expect the lengths tensor to live on the CPU
        packed_x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu())

        # packed_x, packed_output: PackedSequence objects
        # hidden: [num_layers * num_directions, batch_size, hidden_size]
        # cell:   [num_layers * num_directions, batch_size, hidden_size]
        # Note: hidden holds the final hidden state of every layer and direction,
        #       while output below holds the hidden state at every time step.
        #       cell is the cell state; between adjacent time steps the cell state
        #       usually changes little, whereas the hidden state can change a lot.
        packed_output, (hidden, cell) = self.rnn(packed_x)

        # output: [max_seq_len, batch_size, hidden_size * num_directions]
        # output_lengths: [batch_size]
        # output is what a subsequent fully connected layer would consume
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)

        return hidden, output


class TextRNN(nn.Module):

    def __init__(self, embedding_dim, output_dim, hidden_size, num_layers, bidirectional, dropout,
                 pretrained_embeddings):
        super(TextRNN, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(
            pretrained_embeddings, freeze=False)
        self.rnn = LSTM(embedding_dim, hidden_size, num_layers, bidirectional, dropout)

        self.fc = Linear(hidden_size * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        text, text_lengths = x
        # text: [sent len, batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded: [sent len, batch size, emb dim]

        hidden, outputs = self.rnn(embedded, text_lengths)

        hidden = self.dropout(
            torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))  # concatenate the forward and backward final states of the last layer

        return self.fc(hidden)


if __name__ == '__main__':
    embedding_file = '/home/jason/Desktop/data/embeddings/glove/glove.840B.300d.txt'
    path = '/home/jason/Desktop/data/SST-2/'

    cache_dir = '.cache/'
    batch_size = 6
    vectors = vocab.Vectors(embedding_file, cache_dir)

    text_field = data.Field(tokenize='spacy',
                            lower=True,
                            include_lengths=True,
                            fix_length=5)
    label_field = data.LabelField(dtype=torch.long)

    train, dev, test = data.TabularDataset.splits(path=path,
                                                  train='train.tsv',
                                                  validation='dev.tsv',
                                                  test='test.tsv',
                                                  format='tsv',
                                                  skip_header=True,
                                                  fields=[('text', text_field), ('label', label_field)])

    text_field.build_vocab(train,
                           dev,
                           test,
                           max_size=25000,
                           vectors=vectors,
                           unk_init=torch.Tensor.normal_)
    label_field.build_vocab(train, dev, test)

    pretrained_embeddings = text_field.vocab.vectors
    labels = label_field.vocab.itos  # the label vocab has no pretrained vectors; itos lists the label strings

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test),
                                                                 batch_sizes=(batch_size, len(dev), len(test)),
                                                                 sort_key=lambda x: len(x.text),
                                                                 sort_within_batch=True,
                                                                 repeat=False,
                                                                 shuffle=True,
                                                                 device=device
                                                                )

    model = TextRNN(300, 2, 200, 2, True, 0.4, pretrained_embeddings).to(device)  # keep the model on the same device as the batches

    for step, batch in enumerate(tqdm(train_iter, desc="Iteration")):
        logits = model(batch.text)
        break
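
Because the Field above is built with include_lengths=True, batch.text in the loop is a (token_ids, lengths) pair, which is exactly the (text, text_lengths) tuple that TextRNN.forward unpacks; sort_within_batch=True keeps each batch sorted by length, as pack_padded_sequence expects. A small sanity-check sketch, assuming the batch and logits variables from the loop above are still in scope:

text, text_lengths = batch.text
print(text.shape)          # [fix_length, batch_size]
print(text_lengths.shape)  # [batch_size]
print(logits.shape)        # [batch_size, output_dim], i.e. one score per class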

3. A short example of pack_padded_sequence and pad_packed_sequence

train_x = [torch.tensor([1, 1, 1, 1, 1, 1, 1]),
           torch.tensor([3, 3, 3, 3, 3]),
           torch.tensor([6, 6])]
x = nn.utils.rnn.pad_sequence(train_x, batch_first=True)
print('>1: ', x)
pack_x = nn.utils.rnn.pack_padded_sequence(x, [7, 5, 2], batch_first=True)
print('>2: ', pack_x)
reverse_x = nn.utils.rnn.pad_packed_sequence(pack_x)
print('>3: ', reverse_x)

The output is:

>1:  tensor([[1, 1, 1, 1, 1, 1, 1],
        [3, 3, 3, 3, 3, 0, 0],
        [6, 6, 0, 0, 0, 0, 0]])
>2:  PackedSequence(data=tensor([1, 3, 6, 1, 3, 6, 1, 3, 1, 3, 1, 3, 1, 1]), batch_sizes=tensor([3, 3, 2, 2, 2, 1, 1]), sorted_indices=None, unsorted_indices=None)
>3:  (tensor([[1, 3, 6],
        [1, 3, 6],
        [1, 3, 0],
        [1, 3, 0],
        [1, 3, 0],
        [1, 0, 0],
        [1, 0, 0]]), tensor([7, 5, 2]))
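
One thing the example glosses over: pack_padded_sequence assumes the batch is already sorted by length in descending order (hence sort_within_batch=True in the demo). In PyTorch 1.1 and later you can pass enforce_sorted=False and let the function sort, and later restore, the order internally. A small variant of the same example under that assumption:

# the same three sequences, deliberately not sorted by length
unsorted_x = nn.utils.rnn.pad_sequence([torch.tensor([6, 6]),
                                        torch.tensor([1, 1, 1, 1, 1, 1, 1]),
                                        torch.tensor([3, 3, 3, 3, 3])],
                                       batch_first=True)
packed = nn.utils.rnn.pack_padded_sequence(unsorted_x, [2, 7, 5],
                                           batch_first=True, enforce_sorted=False)
padded, lengths = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
print(padded)   # rows come back in the original (unsorted) order
print(lengths)  # tensor([2, 7, 5])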