Source code for mwptoolkit.module.Graph.gcn

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 21:49:49
# @File: gcn.py


import torch
from torch import nn
from torch.nn import functional as F

from mwptoolkit.module.Layer.graph_layers import GraphConvolution


[docs]class GCN(nn.Module): def __init__(self, in_feat_dim, nhid, out_feat_dim, dropout): super(GCN, self).__init__() self.gc1 = GraphConvolution(in_feat_dim, nhid) self.gc2 = GraphConvolution(nhid, out_feat_dim) self.dropout = dropout
[docs] def forward(self, x, adj): """ Args: x (torch.Tensor): input features, shape [batch_size, node_num, in_feat_dim] adj (torch.Tensor): adjacency matrix, shape [batch_size, node_num, node_num] Returns: torch.Tensor: gcn_enhance_feature, shape [batch_size, node_num, out_feat_dim] """ x = F.relu(self.gc1(x, adj)) x = F.dropout(x, self.dropout, training=self.training) x = self.gc2(x, adj) return x