MWPToolkit API:
greedy_search(logits)
Find the index of the maximum logit.
Parameters: logits (torch.Tensor) – logits distribution
Returns: the index of the chosen token
Return type: torch.Tensor
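Here is a minimal sketch of what greedy selection computes. It assumes the operation amounts to taking the argmax over the last dimension, which in PyTorch would be roughly logits.argmax(dim=-1). The function below reproduces that behavior on plain nested Python lists so it runs without torch. The name greedy_search_sketch is hypothetical and is not part of MWPToolkit.

```python
def greedy_search_sketch(logits):
    """Return the index of the largest logit along the last dimension.

    Hypothetical pure-Python stand-in for logits.argmax(dim=-1);
    nested lists are treated as a batch of logit rows.
    """
    # Recurse through leading (batch) dimensions until a flat row is reached.
    if logits and isinstance(logits[0], (list, tuple)):
        return [greedy_search_sketch(row) for row in logits]
    # Greedy choice: the position with the highest score in this row.
    # max() returns the first maximal item, so ties go to the lowest index.
    return max(range(len(logits)), key=lambda i: logits[i])


# Usage: one token index is chosen per row of the batch.
batch = [[0.1, 2.5, -1.0], [3.0, 0.2, 0.9]]
print(greedy_search_sketch(batch))  # [1, 0]
```

Ties resolve to the lowest index. The current PyTorch documentation states that torch.argmax also returns the index of the first maximal value.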