mwptoolkit.module.Strategy.sampling
- mwptoolkit.module.Strategy.sampling.topk_sampling(logits, temperature=1.0, top_k=0, top_p=0.9)
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering, then sample a token index from the filtered distribution.
- Parameters
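logits (torch.Tensor) – logits distribution.
temperature (float) – sampling temperature applied to the logits (default 1.0).
top_k (int) – if > 0, keep only the top_k tokens with the highest probability (top-k filtering).
top_p (float) – if > 0.0, keep the smallest set of highest-probability tokens whose cumulative probability is >= top_p (nucleus filtering).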
- Returns
the index of the sampled token.
- Return type
torch.Tensor
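The following is a minimal standalone sketch of top-k filtering followed by nucleus (top-p) filtering and sampling in PyTorch, to show what the parameters above control. It is not the mwptoolkit source: the function name `topk_topp_sample`, dividing the logits by `temperature`, masking with `-inf`, and treating `top_p >= 1.0` as "disabled" are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def topk_topp_sample(logits: torch.Tensor, temperature: float = 1.0,
                     top_k: int = 0, top_p: float = 0.9) -> torch.Tensor:
    """Hypothetical sketch: sample one token index per row of `logits`.

    `logits` has shape (..., vocab_size); the result has shape (..., 1).
    """
    # Assumption: temperature divides the logits. Division returns a new
    # tensor, so the caller's logits are left unchanged.
    logits = logits / temperature
    vocab_size = logits.size(-1)

    # Top-k: mask every logit smaller than the k-th largest one in its row.
    top_k = min(top_k, vocab_size)
    if top_k > 0:
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))

    # Top-p: keep the smallest prefix of sorted tokens whose cumulative
    # probability reaches top_p (assumption: top_p >= 1.0 disables it).
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of all tokens ranked strictly above each token.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # Drop a token once the tokens before it already reach top_p;
        # the most probable token always has mass_before == 0 and is kept.
        remove_sorted = mass_before >= top_p
        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    # torch.multinomial accepts only 1-D or 2-D input, so flatten leading dims.
    idx = torch.multinomial(probs.reshape(-1, vocab_size), num_samples=1)
    return idx.reshape(*probs.shape[:-1], 1)
```

Masking with `-inf` rather than a large negative constant gives filtered tokens exactly zero probability after the softmax. With `top_k=1` the sketch always returns the arg-max index, which is a quick way to check the filtering logic.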