Source code for mwptoolkit.module.Attention.tree_attention

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 09:11:19
# @File: tree_attention.py


import torch
from torch import nn


class TreeAttention(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(TreeAttention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1)
    def forward(self, hidden, encoder_outputs, seq_mask=None):
        """
        Args:
            hidden (torch.Tensor): hidden representation, shape [1, batch_size, input_size].
            encoder_outputs (torch.Tensor): output from encoder, shape [sequence_length, batch_size, hidden_size].
            seq_mask (torch.Tensor): sequence mask, shape [batch_size, sequence_length].

        Returns:
            attn_energies (torch.Tensor): attention energies, shape [batch_size, 1, sequence_length].
        """
        max_len = encoder_outputs.size(0)

        repeat_dims = [1] * hidden.dim()
        repeat_dims[0] = max_len
        hidden = hidden.repeat(*repeat_dims)  # S x B x H
        this_batch_size = encoder_outputs.size(1)

        energy_in = torch.cat((hidden, encoder_outputs), 2).view(-1, self.input_size + self.hidden_size)

        score_feature = torch.tanh(self.attn(energy_in))
        attn_energies = self.score(score_feature)  # (S x B) x 1
        attn_energies = attn_energies.squeeze(1)
        attn_energies = attn_energies.view(max_len, this_batch_size).transpose(0, 1)  # B x S
        if seq_mask is not None:
            attn_energies = attn_energies.masked_fill_(seq_mask.bool(), -1e12)
        attn_energies = nn.functional.softmax(attn_energies, dim=1)  # B x S

        return attn_energies.unsqueeze(1)
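
A minimal usage sketch follows; it is not part of the mwptoolkit source. The tensor shapes mirror the forward() docstring above, and the batch size, sequence length, and hidden size are arbitrary illustrative values (input_size is set equal to hidden_size for simplicity).

# --- Usage sketch (illustrative; not part of the mwptoolkit source) ---
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 4, 10, 128

    # input_size == hidden_size here purely for illustration.
    attention = TreeAttention(input_size=hidden_size, hidden_size=hidden_size)

    hidden = torch.randn(1, batch_size, hidden_size)                 # [1, B, H]
    encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)  # [S, B, H]

    # 1 marks padded positions; here the last two steps of every sequence are padding.
    seq_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    seq_mask[:, -2:] = 1

    attn_weights = attention(hidden, encoder_outputs, seq_mask)
    print(attn_weights.shape)  # torch.Size([4, 1, 10])
    # Each row of weights sums to 1, with near-zero mass on the masked positions.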