mwptoolkit.module.Graph.gcn

class mwptoolkit.module.Graph.gcn.GCN(in_feat_dim, nhid, out_feat_dim, dropout)

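Bases: torch.nn.Module

Graph convolutional network (GCN) that computes enhanced node features from input node features and a graph adjacency matrix.

Parameters
  • in_feat_dim (int) – dimension of the input node features

  • nhid (int) – dimension of the hidden layer

  • out_feat_dim (int) – dimension of the output node features

  • dropout (float) – dropout probability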

forward(x, adj)
Parameters
  • x (torch.Tensor) – input node features, shape [batch_size, node_num, in_feat_dim]

  • adj (torch.Tensor) – adjacency matrix, shape [batch_size, node_num, node_num]

Returns

gcn_enhance_feature – GCN-enhanced node features, shape [batch_size, node_num, out_feat_dim]

Return type

torch.Tensor
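This page does not say what forward computes internally. A common design for a module with this constructor signature is the two-layer GCN of Kipf & Welling: graph convolution, ReLU, dropout, graph convolution. The sketch below assumes that structure and batched inputs with the shapes documented above. The names GraphConvolution, GCNSketch and normalize_adjacency are hypothetical, and adding self-loops plus symmetric normalization to adj is an assumption about how callers might prepare the adjacency matrix, not something stated here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: D^{-1/2} (A + I) D^{-1/2}, applied per graph in the batch."""
    eye = torch.eye(adj.size(-1), dtype=adj.dtype, device=adj.device)
    a_hat = adj + eye                          # add self-loops (broadcast over batch)
    d_inv_sqrt = a_hat.sum(dim=-1).pow(-0.5)   # degrees are >= 1 for non-negative adj
    # Scale rows and columns: entry (i, j) becomes a_ij / sqrt(d_i * d_j).
    return d_inv_sqrt.unsqueeze(-1) * a_hat * d_inv_sqrt.unsqueeze(-2)


class GraphConvolution(nn.Module):
    """Hypothetical layer computing adj @ (x W) + b on batched inputs."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.weight = nn.Linear(in_dim, out_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        support = self.weight(x)               # [batch, node_num, out_dim]
        return torch.bmm(adj, support) + self.bias


class GCNSketch(nn.Module):
    """Two-layer GCN mirroring the documented GCN signature (assumed structure)."""

    def __init__(self, in_feat_dim: int, nhid: int, out_feat_dim: int, dropout: float):
        super().__init__()
        self.gc1 = GraphConvolution(in_feat_dim, nhid)
        self.gc2 = GraphConvolution(nhid, out_feat_dim)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.gc1(x, adj))           # [batch, node_num, nhid]
        # Dropout is only applied while self.training is True.
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.gc2(h, adj)                # [batch, node_num, out_feat_dim]


# Two graphs with 3 nodes each: a path 0-1-2 and a graph with no edges.
adj = torch.tensor([
    [[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]],
    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
])
norm_adj = normalize_adjacency(adj)
x = torch.randn(2, 3, 8)                       # [batch_size, node_num, in_feat_dim]
model = GCNSketch(in_feat_dim=8, nhid=16, out_feat_dim=4, dropout=0.5)
out = model(x, norm_adj)
print(out.shape)  # torch.Size([2, 3, 4])
```

Using torch.bmm keeps the batch dimension explicit, so each graph in the batch is aggregated only with its own adjacency matrix.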

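training: bool – whether the module is in training mode; inherited from torch.nn.Module and toggled by train() and eval().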
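The training flag is the standard torch.nn.Module attribute: train() sets it to True and eval() sets it to False on the module and all of its submodules, and dropout layers read it. Assuming GCN applies its dropout through this flag, calling eval() before inference makes its output deterministic. The sketch below uses a stand-in nn.Sequential with a dropout layer rather than GCN itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in module with dropout; every nn.Module (GCN included) carries .training.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(3, 4)

print(model.training)        # True: modules start in training mode
model.eval()                 # sets training=False recursively
print(model[1].training)     # False: the dropout submodule follows its parent
with torch.no_grad():
    y1, y2 = model(x), model(x)
print(torch.equal(y1, y2))   # True: dropout is disabled in eval mode
model.train()                # switch back to training mode
```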