Source code for mwptoolkit.module.Attention.group_attention

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 11:10:10
# @File: group_attention.py


import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from mwptoolkit.utils.utils import clones


def group_mask(batch, type="self", pad=0):
    """Build a [batch_size, seq_len, seq_len] 0/1 mask from a matrix of group ids.

    ``type`` selects which relation is kept: "self" (same clause group), "between"
    (other non-question groups) or "question" (the question group, marked with id 1000).
    """
    length = batch.shape[1]
    lis = []
    if type == "self":
        for tok in batch:
            mask = np.zeros(tok.shape)
            mask = np.expand_dims(mask, -1)
            for ele in tok:
                if ele == pad:
                    copy = np.zeros(length)
                else:
                    copy = tok.copy()
                    # positions in the same group as ele get 1, everything else 0;
                    # the question group (1000) is excluded unless ele is a question token itself
                    if ele != 1000:
                        copy[copy == 1000] = 0
                    copy[copy != ele] = 0
                    copy[copy == ele] = 1
                copy = np.expand_dims(copy, -1)
                mask = np.concatenate((mask, copy), axis=1)
            mask = mask[:, 1:]
            mask = mask.transpose()
            mask = np.expand_dims(mask, 0)
            lis.append(mask)
        res = np.concatenate(tuple(lis))
    elif type == "between":
        for tok in batch:
            mask = np.zeros(tok.shape)
            mask = np.expand_dims(mask, -1)
            for ele in tok:
                if ele == pad:
                    copy = np.zeros(length)
                else:
                    copy = tok.copy()
                    # positions in a different, non-question, non-pad group get 1
                    copy[copy == 1000] = 0
                    copy[copy == ele] = 0
                    copy[copy != 0] = 1
                copy = np.expand_dims(copy, -1)
                mask = np.concatenate((mask, copy), axis=1)
            mask = mask[:, 1:]
            mask = mask.transpose()
            mask = np.expand_dims(mask, 0)
            lis.append(mask)
        res = np.concatenate(tuple(lis))
    elif type == "question":
        for tok in batch:
            mask = np.zeros(tok.shape)
            mask = np.expand_dims(mask, -1)
            for ele in tok:
                if ele == pad:
                    copy = np.zeros(length)
                else:
                    copy = tok.copy()
                    # question positions get 1 ...
                    copy[copy != 1000] = 0
                    copy[copy == 1000] = 1
                    # ... unless ele itself is a question token, in which case the mask is inverted
                    # so that question tokens attend to the non-question positions
                    if ele == 1000:
                        copy[copy == 0] = -1
                        copy[copy == 1] = 0
                        copy[copy == -1] = 1
                copy = np.expand_dims(copy, -1)
                mask = np.concatenate((mask, copy), axis=1)
            mask = mask[:, 1:]
            mask = mask.transpose()
            mask = np.expand_dims(mask, 0)
            lis.append(mask)
        res = np.concatenate(tuple(lis))
    else:
        return "error"
    return res
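
A minimal sketch of what group_mask produces for each mask type, using a hand-built row of group ids (in the pipeline these ids come from src_to_mask); the import path is this module's own:

    import numpy as np
    from mwptoolkit.module.Attention.group_attention import group_mask

    # two clause groups (1 and 2), one question group (1000) and one padding position (0)
    groups = np.array([[1, 1, 2, 2, 1000, 0]])

    self_mask = group_mask(groups, "self")          # same-group attention
    between_mask = group_mask(groups, "between")    # attention to the other clause groups
    question_mask = group_mask(groups, "question")  # attention to / from the question group

    print(self_mask.shape)   # (1, 6, 6)
    print(self_mask[0][0])   # [1, 1, 0, 0, 0, 0]: the first token sees only its own group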
def src_to_mask(src, vocab_dict):
    """Split every sequence in ``src`` into clause groups at "," and "." and return the group ids.

    Padding positions get id 0 and the last group (the question sentence) gets id 1000.
    """
    src = src.cpu().numpy()
    batch_data_mask_tok = []
    for encode_sen_idx in src:
        token = 1
        mask = [0] * len(encode_sen_idx)
        for num in range(len(encode_sen_idx)):
            mask[num] = token
            # start a new group after every "," or "." that is not the last token
            if (encode_sen_idx[num] == vocab_dict["."] or encode_sen_idx[num] == vocab_dict[","]) \
                    and num != len(encode_sen_idx) - 1:
                token += 1
            if encode_sen_idx[num] == 0:
                mask[num] = 0
        # mark the last group (the question sentence) with the special id 1000
        for num in range(len(encode_sen_idx)):
            if mask[num] == token and token != 1:
                mask[num] = 1000
        batch_data_mask_tok.append(mask)
    return np.array(batch_data_mask_tok)
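
The effect of src_to_mask is easiest to see on a toy batch; the vocabulary ids below are made up, only the "," and "." entries matter:

    import torch
    from mwptoolkit.module.Attention.group_attention import src_to_mask

    vocab_dict = {",": 5, ".": 6}   # hypothetical separator ids
    src = torch.tensor([[11, 12, 5, 13, 14, 6, 21, 22, 0]])

    print(src_to_mask(src, vocab_dict))
    # group ids: [1, 1, 1, 2, 2, 2, 1000, 1000, 0]
    # tokens up to the first "," form group 1, the next clause is group 2,
    # the trailing segment is marked as the question (1000) and the padding stays 0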
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        query (torch.Tensor): shape [batch_size, head_nums, sequence_length, dim_k].
        key (torch.Tensor): shape [batch_size, head_nums, sequence_length, dim_k].
        value (torch.Tensor): shape [batch_size, head_nums, sequence_length, dim_k].
        mask (torch.Tensor): group attention mask, shape [batch_size, head_nums, sequence_length, sequence_length].
        dropout (torch.nn.Dropout): dropout layer applied to the attention weights.

    Returns:
        tuple(torch.Tensor, torch.Tensor): the attended values and the attention weights.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # positions where the mask is 0 are excluded from the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
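
A quick sanity check of attention with random per-head tensors (shapes as produced inside GroupAttention.forward below):

    import torch
    from mwptoolkit.module.Attention.group_attention import attention

    batch_size, head_nums, seq_len, d_k = 2, 4, 6, 16
    q = torch.randn(batch_size, head_nums, seq_len, d_k)
    k = torch.randn(batch_size, head_nums, seq_len, d_k)
    v = torch.randn(batch_size, head_nums, seq_len, d_k)
    # an all-ones mask lets every position attend to every other position
    mask = torch.ones(batch_size, head_nums, seq_len, seq_len, dtype=torch.bool)

    out, weights = attention(q, k, v, mask=mask)
    print(out.shape)      # torch.Size([2, 4, 6, 16])
    print(weights.shape)  # torch.Size([2, 4, 6, 6])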
class GroupAttention(nn.Module):
    """Multi-head attention whose heads attend within a clause group, between groups,
    globally over the non-padded sequence, and to the question part of the problem text.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Take in model size and number of heads."""
        super(GroupAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
    def get_mask(self, src, split_list, pad=0):
        """
        Args:
            src (torch.Tensor): source sequence, shape [batch_size, sequence_length].
            split_list (list): token ids used to split the input into groups.
            pad (int): pad token index.

        Returns:
            torch.Tensor: group attention mask, shape [batch_size, 4, sequence_length, sequence_length].
        """
        device = src.device
        mask = self.src_to_mask(src, split_list)
        # one [batch_size, 1, seq_len, seq_len] mask per attention type
        self.src_mask_self = torch.from_numpy(group_mask(mask, "self", pad).astype('uint8')).unsqueeze(1)
        self.src_mask_between = torch.from_numpy(group_mask(mask, "between", pad).astype('uint8')).unsqueeze(1)
        self.src_mask_question = torch.from_numpy(group_mask(mask, "question", pad).astype('uint8')).unsqueeze(1)
        # the global mask only hides padding positions
        self.src_mask_global = (src != pad).unsqueeze(-2).unsqueeze(1)
        self.src_mask_global = self.src_mask_global.expand(self.src_mask_self.shape)
        self.final = torch.cat((self.src_mask_between.to(device).bool(),
                                self.src_mask_self.to(device).bool(),
                                self.src_mask_global.to(device),
                                self.src_mask_question.to(device).bool()), 1)
        return self.final.to(device)
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query (torch.Tensor): shape [batch_size, sequence_length, hidden_size].
            key (torch.Tensor): shape [batch_size, sequence_length, hidden_size].
            value (torch.Tensor): shape [batch_size, sequence_length, hidden_size].
            mask (torch.Tensor): group attention mask, shape [batch_size, 4, sequence_length, sequence_length].

        Returns:
            torch.Tensor: shape [batch_size, sequence_length, hidden_size].
        """
        if mask is not None:
            if len(mask.shape) < 4:
                # Same mask applied to all h heads.
                mask = mask.unsqueeze(1)
            else:
                # Duplicate the four group masks along the head dimension so that
                # every head receives one of them.
                mask = torch.cat((mask, mask), 1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k.
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
    def src_to_mask(self, src, split_list):
        """Split every sequence in ``src`` into groups at the token ids in ``split_list``.

        Padding positions get id 0 and the last group (the question sentence) gets id 1000.
        """
        src = src.cpu().numpy()
        batch_data_mask_tok = []
        for encode_sen_idx in src:
            token = 1
            mask = [0] * len(encode_sen_idx)
            for num in range(len(encode_sen_idx)):
                mask[num] = token
                # start a new group after every separator token that is not the last token
                if encode_sen_idx[num] in split_list and num != len(encode_sen_idx) - 1:
                    token += 1
                if encode_sen_idx[num] == 0:
                    mask[num] = 0
            # mark the last group (the question sentence) with the special id 1000
            for num in range(len(encode_sen_idx)):
                if mask[num] == token and token != 1:
                    mask[num] = 1000
            batch_data_mask_tok.append(mask)
        return np.array(batch_data_mask_tok)
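
An end-to-end sketch of how get_mask and forward fit together; the vocabulary size, the separator ids in split_list and the toy source sequence are all assumptions for the example, and heads/d_model are picked only so that d_model is divisible by the head count:

    import torch
    from torch import nn
    from mwptoolkit.module.Attention.group_attention import GroupAttention

    vocab_size, d_model, heads = 30, 512, 8
    split_list = [5, 6]   # hypothetical ids of "," and "." in the vocabulary

    src = torch.tensor([[11, 12, 5, 13, 14, 6, 21, 22, 0]])     # [batch_size, seq_len]
    embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
    x = embedding(src)                                          # [1, 9, 512]

    group_att = GroupAttention(heads, d_model)
    mask = group_att.get_mask(src, split_list)                  # [1, 4, 9, 9]
    out = group_att(x, x, x, mask)                              # [1, 9, 512]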