Source code for mwptoolkit.module.Strategy.sampling

# -*- encoding: utf-8 -*-
# @Author: Yihuai Lan
# @Time: 2021/08/29 22:12:37
# @File: sampling.py


import torch
from torch.nn import functional as F

def topk_sampling(logits, temperature=1.0, top_k=0, top_p=0.9):
    r"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits (torch.Tensor): logits distribution.
        temperature (float): temperature used to rescale the logits before filtering.
        top_k (int): if > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p (float): if > 0.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).

    Returns:
        torch.Tensor: the chosen index of token.
    """
    logits = logits / temperature
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Keep only the top-k logits; everything below the k-th largest value is pushed to -1e10
        values = torch.topk(logits, top_k)[0]  # B x 1 x top_k
        batch_mins = values[:, :, -1].expand_as(logits.squeeze(1)).unsqueeze(1)
        logits = torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
    if 0.0 < top_p < 1.0:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, _ = torch.sort(probs, descending=True, dim=-1)
        cumprobs = sorted_probs.cumsum(dim=-1)
        # Create mask for all cumulative probabilities less than p
        mask = cumprobs < top_p
        # First mask must always be pickable
        mask = F.pad(mask[:, :, :-1], (1, 0, 0, 0), value=1)
        masked_probs = torch.where(mask, sorted_probs, torch.tensor(float('inf')).to(probs))
        batch_mins = masked_probs.min(dim=-1, keepdim=True)[0].expand_as(logits)
        # Mask out all logits (tail) that are too small
        logits = torch.where(probs < batch_mins, torch.tensor(float('-inf')).to(logits), logits)
    # Renormalize the filtered logits and sample one token index per batch element
    probabilities = F.softmax(logits, dim=-1)
    probabilities = probabilities.squeeze(1)
    token_idx = torch.multinomial(probabilities, 1)
    return token_idx
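
The function operates on decoder logits for a single generation step; the squeeze(1) and values[:, :, -1] indexing imply an input of shape (batch_size, 1, vocab_size) and an output of shape (batch_size, 1). The following is a minimal usage sketch, not part of the original module; the batch size, vocabulary size, and parameter values are illustrative assumptions.

if __name__ == "__main__":
    # Illustrative shapes: 4 sequences, vocabulary of 100 tokens.
    batch_size, vocab_size = 4, 100
    logits = torch.randn(batch_size, 1, vocab_size)

    # Sample one token per sequence with temperature scaling, top-k and nucleus filtering.
    token_idx = topk_sampling(logits, temperature=0.8, top_k=10, top_p=0.9)
    print(token_idx.shape)  # torch.Size([4, 1])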