mwptoolkit.module.Strategy.sampling

mwptoolkit.module.Strategy.sampling.topk_sampling(logits, temperature=1.0, top_k=0, top_p=0.9)

Filter a distribution of logits using top-k and/or nucleus (top-p) filtering, then choose a token index from the filtered distribution.

Parameters
  • logits (torch.Tensor) – logits distribution

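  • temperature (float) – sampling temperature. Defaults to 1.0.

  • top_k (int) – if > 0, keep only the k tokens with the highest probability (top-k filtering). Defaults to 0.

  • top_p (float) – if > 0.0, keep the top tokens whose cumulative probability is >= top_p (nucleus filtering). Defaults to 0.9.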

Returns

The index of the chosen token.

Return type

torch.Tensor
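
Example

The following is a minimal sketch of what top-k and nucleus (top-p) filtering do, not the library's implementation. To stay dependency-free it works on a plain list of floats instead of a torch.Tensor. The helper names filter_logits and sample_index are hypothetical, and two details are assumptions: the logits are divided by temperature before filtering, and top-k is applied before top-p.

```python
import math
import random


def filter_logits(logits, temperature=1.0, top_k=0, top_p=0.9):
    """Return temperature-scaled logits with filtered-out entries set to -inf.

    Illustrative sketch only; `logits` is a non-empty list of floats.
    """
    # Assumption: temperature divides the logits (must be non-zero).
    scaled = [x / temperature for x in logits]

    # Token indices ordered from highest to lowest logit.
    kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)

    # Top-k filtering: keep only the k highest-scoring tokens.
    if top_k > 0:
        kept = kept[:top_k]

    # Nucleus filtering: keep the smallest prefix of the remaining tokens
    # whose cumulative probability reaches top_p.
    if top_p > 0.0:
        peak = scaled[kept[0]]
        exps = [math.exp(scaled[i] - peak) for i in kept]
        total = sum(exps)
        cumulative = 0.0
        nucleus = []
        for i, e in zip(kept, exps):
            nucleus.append(i)
            cumulative += e / total
            if cumulative >= top_p:
                break
        kept = nucleus

    kept_set = set(kept)
    return [x if i in kept_set else float("-inf") for i, x in enumerate(scaled)]


def sample_index(logits, temperature=1.0, top_k=0, top_p=0.9, rng=random):
    """Sample one token index from the filtered distribution."""
    filtered = filter_logits(logits, temperature, top_k, top_p)
    peak = max(filtered)  # finite: the best token is always kept
    # exp(-inf) is 0.0, so filtered-out tokens can never be drawn.
    weights = [math.exp(x - peak) for x in filtered]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]
```

In this sketch, setting top_k=1 makes the choice deterministic (the highest-scoring token), while top_k=0 and top_p=0.0 disable both filters and sample from the full temperature-scaled distribution.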