Source code for mwptoolkit.module.Graph.graph_module

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:00:28
# @File: graph_module.py


import torch
from torch import nn
from torch.nn import functional as F

from mwptoolkit.module.Layer.graph_layers import PositionwiseFeedForward, LayerNorm
from mwptoolkit.module.Graph.gcn import GCN

class Graph_Module(nn.Module):
    def __init__(self, indim, hiddim, outdim, dropout=0.3):
        """
        Args:
            indim: dimensionality of input node features
            hiddim: dimensionality of the joint hidden embedding
            outdim: dimensionality of the output node features
            dropout: dropout probability
        """
        super(Graph_Module, self).__init__()
        self.in_dim = indim
        # multi-head graph convolution: the output dimension is split across h heads
        self.h = 4
        self.d_k = outdim // self.h

        self.graph = nn.ModuleList()
        for _ in range(self.h):
            self.graph.append(GCN(indim, hiddim, self.d_k, dropout))

        self.feed_foward = PositionwiseFeedForward(indim, hiddim, outdim, dropout)
        self.norm = LayerNorm(outdim)
    def get_adj(self, graph_nodes):
        """
        Args:
            graph_nodes (torch.Tensor): input features, shape [batch_size, node_num, in_feat_dim]

        Returns:
            torch.Tensor: adjacency matrix, shape [batch_size, node_num, node_num]
        """
        # NOTE: edge_layer_1, edge_layer_2 and combined_dim are not defined in
        # __init__ above; this inferred-adjacency path assumes they are provided elsewhere.
        self.K = graph_nodes.size(1)
        graph_nodes = graph_nodes.contiguous().view(-1, self.in_dim)

        # layer 1
        h = self.edge_layer_1(graph_nodes)
        h = F.relu(h)

        # layer 2
        h = self.edge_layer_2(h)
        h = F.relu(h)

        # outer product
        h = h.view(-1, self.K, self.combined_dim)
        adjacency_matrix = torch.matmul(h, h.transpose(1, 2))

        adjacency_matrix = self.b_normal(adjacency_matrix)

        return adjacency_matrix
    def normalize(self, A, symmetric=True):
        """
        Args:
            A (torch.Tensor): adjacency matrix, shape [node_num, node_num]

        Returns:
            torch.Tensor: normalized adjacency matrix, shape [node_num, node_num]
        """
        A = A + torch.eye(A.size(0)).cuda().float()
        d = A.sum(1)
        if symmetric:
            # D = D^{-1/2}
            D = torch.diag(torch.pow(d, -0.5))
            return D.mm(A).mm(D)
        else:
            D = torch.diag(torch.pow(d, -1))
            return D.mm(A)
    def b_normal(self, adj):
        # normalize each adjacency matrix in the batch
        batch = adj.size(0)
        for i in range(batch):
            adj[i] = self.normalize(adj[i])
        return adj
    def forward(self, graph_nodes, graph):
        """
        Args:
            graph_nodes (torch.Tensor): input features, shape [batch_size, node_num, in_feat_dim]
            graph (torch.Tensor): prebuilt batch graphs, shape [batch_size, graph_num, node_num, node_num];
                an empty tensor triggers adjacency inference via get_adj.

        Returns:
            tuple(torch.Tensor, torch.Tensor):
                adj, the adjacency matrices actually used, and
                graph_encode_features, shape [batch_size, node_num, out_feat_dim].
        """
        nbatches = graph_nodes.size(0)
        mbatches = graph.size(0)
        if nbatches != mbatches:
            graph_nodes = graph_nodes.transpose(0, 1)

        if not bool(graph.numel()):
            # no prebuilt graph: infer an adjacency matrix from the node features
            adj = self.get_adj(graph_nodes)
            adj_list = [adj, adj, adj, adj]
        else:
            # use two of the prebuilt graphs, each feeding two GCN heads
            adj = graph.float()
            adj_list = [adj[:, 1, :], adj[:, 1, :], adj[:, 4, :], adj[:, 4, :]]

        # run each head's GCN on its adjacency matrix, then concatenate the head outputs
        g_feature = tuple([l(graph_nodes, x) for l, x in zip(self.graph, adj_list)])

        # residual connections around the multi-head GCN and the feed-forward layer
        g_feature = self.norm(torch.cat(g_feature, 2)) + graph_nodes
        graph_encode_features = self.feed_foward(g_feature) + g_feature

        return adj, graph_encode_features
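A minimal usage sketch for Graph_Module (not part of the module itself): the tensor sizes and the random prebuilt graphs below are illustrative assumptions, and the graph tensor carries five relation matrices per instance because forward() reads slices 1 and 4. The residual connections require indim == outdim.

    import torch
    from mwptoolkit.module.Graph.graph_module import Graph_Module

    batch_size, node_num, hidden = 2, 10, 512
    encoder_outputs = torch.randn(batch_size, node_num, hidden)
    # five illustrative 0/1 relation graphs per instance
    graphs = (torch.rand(batch_size, 5, node_num, node_num) > 0.5).long()

    graph_module = Graph_Module(hidden, hidden, hidden)
    adj, graph_features = graph_module(encoder_outputs, graphs)
    # graph_features: [batch_size, node_num, hidden], graph-enhanced encoder states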
class Parse_Graph_Module(nn.Module):
    def __init__(self, hidden_size):
        super(Parse_Graph_Module, self).__init__()

        self.hidden_size = hidden_size
        self.node_fc1 = nn.Linear(hidden_size, hidden_size)
        self.node_fc2 = nn.Linear(hidden_size, hidden_size)
        self.node_out = nn.Linear(hidden_size * 2, hidden_size)
    def normalize(self, graph, symmetric=True):
        d = graph.sum(1)
        if symmetric:
            D = torch.diag(torch.pow(d, -0.5))
            return D.mm(graph).mm(D)
        else:
            D = torch.diag(torch.pow(d, -1))
            return D.mm(graph)
    def forward(self, node, graph):
        graph = graph.float()
        batch_size = node.size(0)
        for i in range(batch_size):
            graph[i] = self.normalize(graph[i])

        # two rounds of propagation over the parse graph
        node_info = torch.relu(self.node_fc1(torch.matmul(graph, node)))
        node_info = torch.relu(self.node_fc2(torch.matmul(graph, node_info)))

        # fuse the original node features with the aggregated neighbourhood information
        agg_node_info = torch.cat((node, node_info), dim=2)
        agg_node_info = torch.relu(self.node_out(agg_node_info))

        return agg_node_info
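A minimal usage sketch for Parse_Graph_Module (illustrative, not part of the module): the parse graph here is a random symmetric adjacency matrix with self-loops added, so every node has a nonzero degree before the symmetric normalization above.

    import torch
    from mwptoolkit.module.Graph.graph_module import Parse_Graph_Module

    batch_size, seq_len, hidden = 2, 8, 512
    token_states = torch.randn(batch_size, seq_len, hidden)

    edges = (torch.rand(batch_size, seq_len, seq_len) > 0.7).float()
    parse_graph = ((edges + edges.transpose(1, 2) + torch.eye(seq_len)) > 0).long()

    parse_module = Parse_Graph_Module(hidden)
    enriched = parse_module(token_states, parse_graph)  # [batch_size, seq_len, hidden]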
class Num_Graph_Module(nn.Module):
    def __init__(self, node_dim):
        super(Num_Graph_Module, self).__init__()

        self.node_dim = node_dim
        self.node1_fc1 = nn.Linear(node_dim, node_dim)
        self.node1_fc2 = nn.Linear(node_dim, node_dim)
        self.node2_fc1 = nn.Linear(node_dim, node_dim)
        self.node2_fc2 = nn.Linear(node_dim, node_dim)
        self.graph_weight = nn.Linear(node_dim * 4, node_dim)
        self.node_out = nn.Linear(node_dim * 2, node_dim)
    def normalize(self, graph, symmetric=True):
        d = graph.sum(1)
        if symmetric:
            D = torch.diag(torch.pow(d, -0.5))
            return D.mm(graph).mm(D)
        else:
            D = torch.diag(torch.pow(d, -1))
            return D.mm(graph)
    def forward(self, node, graph1, graph2):
        graph1 = graph1.float()
        graph2 = graph2.float()
        batch_size = node.size(0)
        for i in range(batch_size):
            graph1[i] = self.normalize(graph1[i], False)
            graph2[i] = self.normalize(graph2[i], False)

        # propagate node features over each of the two number graphs
        node_info1 = torch.relu(self.node1_fc1(torch.matmul(graph1, node)))
        node_info1 = torch.relu(self.node1_fc2(torch.matmul(graph1, node_info1)))
        node_info2 = torch.relu(self.node2_fc1(torch.matmul(graph2, node)))
        node_info2 = torch.relu(self.node2_fc2(torch.matmul(graph2, node_info2)))

        # gated fusion of the two graph representations
        gate = torch.cat((node_info1, node_info2, node_info1 + node_info2, node_info1 - node_info2), dim=2)
        gate = torch.sigmoid(self.graph_weight(gate))
        node_info = gate * node_info1 + (1 - gate) * node_info2

        # combine the original node features with the fused graph information
        agg_node_info = torch.cat((node, node_info), dim=2)
        agg_node_info = torch.relu(self.node_out(agg_node_info))

        return agg_node_info
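A minimal usage sketch for Num_Graph_Module (illustrative assumptions throughout): graph1 and graph2 stand for two relation graphs over the number nodes, e.g. built from pairwise comparisons; self-loops keep every row sum positive for the row-wise (symmetric=False) normalization used above.

    import torch
    from mwptoolkit.module.Graph.graph_module import Num_Graph_Module

    batch_size, num_count, dim = 2, 6, 512
    num_embeddings = torch.randn(batch_size, num_count, dim)

    eye = torch.eye(num_count)
    graph1 = (((torch.rand(batch_size, num_count, num_count) > 0.5).float() + eye) > 0).long()
    graph2 = (((torch.rand(batch_size, num_count, num_count) > 0.5).float() + eye) > 0).long()

    num_module = Num_Graph_Module(dim)
    fused = num_module(num_embeddings, graph1, graph2)  # [batch_size, num_count, dim]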