Source code for mwptoolkit.module.Attention.seq_attention

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 09:10:50
# @File: seq_attention.py


import math

import torch
from torch import nn
from torch.nn import functional as F

class SeqAttention(nn.Module):
    def __init__(self, hidden_size, context_size):
        super(SeqAttention, self).__init__()
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.linear_out = nn.Linear(hidden_size * 2, context_size)
    def forward(self, inputs, encoder_outputs, mask):
        """
        Args:
            inputs (torch.Tensor): shape [batch_size, 1, hidden_size].
            encoder_outputs (torch.Tensor): shape [batch_size, sequence_length, hidden_size].
            mask (torch.BoolTensor): positions to exclude from attention (padding),
                broadcastable to [batch_size, 1, sequence_length], or None.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                output, shape [batch_size, 1, context_size].
                attention, shape [batch_size, 1, sequence_length].
        """
        batch_size = inputs.size(0)
        seq_length = encoder_outputs.size(1)

        attn = torch.bmm(inputs, encoder_outputs.transpose(1, 2))
        if mask is not None:
            attn.data.masked_fill_(mask, -float('inf'))
        attn = F.softmax(attn.view(-1, seq_length), dim=1).view(batch_size, -1, seq_length)

        mix = torch.bmm(attn, encoder_outputs)

        combined = torch.cat((mix, inputs), dim=2)

        output = torch.tanh(self.linear_out(combined.view(-1, 2 * self.hidden_size)))\
            .view(batch_size, -1, self.context_size)

        return output, attn
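
# A hypothetical usage sketch, not part of the original module: a single decoder
# step attending over padded encoder outputs. The helper name and all sizes are
# illustrative assumptions.
def _demo_seq_attention():
    batch_size, seq_len, hidden_size = 2, 5, 8
    layer = SeqAttention(hidden_size, context_size=hidden_size)
    decoder_step = torch.randn(batch_size, 1, hidden_size)
    encoder_outputs = torch.randn(batch_size, seq_len, hidden_size)
    # boolean mask of padded positions; here both sequences have true length 3
    lengths = torch.tensor([3, 3])
    mask = torch.arange(seq_len).unsqueeze(0) >= lengths.unsqueeze(1)  # [batch_size, seq_len]
    output, attn = layer(decoder_step, encoder_outputs, mask.unsqueeze(1))
    return output.shape, attn.shape  # ([2, 1, 8], [2, 1, 5])
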
class Attention(nn.Module):
    """
    Calculate attention.

    Args:
        dim_value (int): Dimension of value.
        dim_query (int): Dimension of query.
        dim_hidden (int): Dimension of hidden layer in attention calculation.
    """
    def __init__(self, dim_value, dim_query, dim_hidden=256, dropout_rate=0.5):
        super(Attention, self).__init__()
        self.relevant_score = \
            MaskedRelevantScore(dim_value, dim_query, dim_hidden)
    def forward(self, value, query, lens):
        """
        Generate variable embedding with attention.

        Args:
            query (torch.FloatTensor): Current hidden state, with size [batch_size, dim_query].
            value (torch.FloatTensor): Sequence to be attended, with size [batch_size, seq_len, dim_value].
            lens (list of int): Lengths of values in a batch.

        Returns:
            torch.FloatTensor: Calculated attention, with size [batch_size, dim_value].
        """
        relevant_scores = self.relevant_score(value, query, lens)
        e_relevant_scores = torch.exp(relevant_scores)
        weights = e_relevant_scores / e_relevant_scores.sum(-1, keepdim=True)
        attention = (weights.unsqueeze(-1) * value).sum(1)
        return attention
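
# A hypothetical usage sketch, not part of the original module: attend over a
# padded batch of value vectors given one query per example. The helper name
# and all sizes are illustrative assumptions.
def _demo_attention():
    batch_size, seq_len, dim_value, dim_query = 2, 4, 6, 10
    layer = Attention(dim_value, dim_query, dim_hidden=16)
    value = torch.randn(batch_size, seq_len, dim_value)
    query = torch.randn(batch_size, dim_query)
    lens = [4, 2]  # the second example has only 2 valid positions
    attention = layer(value, query, lens)
    return attention.shape  # [2, 6]
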
class MaskedRelevantScore(nn.Module):
    """
    Relevant score masked by sequence lengths.

    Args:
        dim_value (int): Dimension of value.
        dim_query (int): Dimension of query.
        dim_hidden (int): Dimension of hidden layer in attention calculation.
    """
    def __init__(self, dim_value, dim_query, dim_hidden=256, dropout_rate=0.0):
        super(MaskedRelevantScore, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.relevant_score = RelevantScore(dim_value, dim_query,
                                            dim_hidden, dropout_rate)
    def forward(self, value, query, lens):
        """
        Choose candidate from candidates.

        Args:
            query (torch.FloatTensor): Current hidden state, with size [batch_size, dim_query].
            value (torch.FloatTensor): Sequence to be attended, with size [batch_size, seq_len, dim_value].
            lens (list of int): Lengths of values in a batch.

        Returns:
            torch.Tensor: Activation for each operand, with size
            [batch, max([len(os) for os in operands])].
        """
        relevant_scores = self.relevant_score(value, query)

        # make mask to mask out padding embeddings
        mask = torch.zeros_like(relevant_scores)
        for b, n_c in enumerate(lens):
            mask[b, n_c:] = -math.inf

        # apply mask
        relevant_scores += mask

        return relevant_scores
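
# A hypothetical check, not part of the original module: scores at padded
# positions come back as -inf, so a downstream exp/softmax gives them zero
# weight. The helper name and all sizes are illustrative assumptions.
def _demo_masked_relevant_score():
    batch_size, seq_len, dim_value, dim_query = 2, 3, 4, 5
    scorer = MaskedRelevantScore(dim_value, dim_query, dim_hidden=8)
    value = torch.randn(batch_size, seq_len, dim_value)
    query = torch.randn(batch_size, dim_query)
    scores = scorer(value, query, lens=[3, 1])
    return scores  # scores[1, 1:] are -inf
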
class RelevantScore(nn.Module):
    def __init__(self, dim_value, dim_query, hidden1, dropout_rate=0):
        super(RelevantScore, self).__init__()
        self.lW1 = nn.Linear(dim_value, hidden1, bias=False)
        self.lW2 = nn.Linear(dim_query, hidden1, bias=False)
        self.b = nn.Parameter(
            torch.normal(torch.zeros(1, 1, hidden1), 0.01))
        self.tanh = nn.Tanh()
        self.lw = nn.Linear(hidden1, 1, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, value, query):
        """
        Args:
            value (torch.FloatTensor): shape [batch, seq_len, dim_value].
            query (torch.FloatTensor): shape [batch, dim_query].
        """
        u = self.tanh(self.dropout(
            self.lW1(value) + self.lW2(query).unsqueeze(1) + self.b))
        # u.size() == (batch, seq_len, dim_hidden)
        return self.lw(u).squeeze(-1)
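
# A hypothetical shape check, not part of the original module, for the additive
# (Bahdanau-style) scorer: one unnormalized score per value position. The helper
# name and all sizes are illustrative assumptions.
def _demo_relevant_score():
    batch_size, seq_len, dim_value, dim_query = 2, 7, 4, 5
    scorer = RelevantScore(dim_value, dim_query, hidden1=8)
    value = torch.randn(batch_size, seq_len, dim_value)
    query = torch.randn(batch_size, dim_query)
    scores = scorer(value, query)
    return scores.shape  # [2, 7]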