MWPToolkit API:
BERTGen
  BERTGen.calculate_loss()
  BERTGen.convert_idx2symbol()
  BERTGen.decode()
  BERTGen.decode_()
  BERTGen.decoder_forward()
  BERTGen.encoder_forward()
  BERTGen.forward()
  BERTGen.model_test()
  BERTGen.predict()
  BERTGen.training
GPT2
  GPT2.calculate_loss()
  GPT2.convert_idx2symbol()
  GPT2.decode_()
  GPT2.decoder_forward()
  GPT2.encode_()
  GPT2.forward()
  GPT2.list2str()
  GPT2.model_test()
  GPT2.predict()
  GPT2.training
RobertaGen
  RobertaGen.calculate_loss()
  RobertaGen.convert_idx2symbol()
  RobertaGen.decode()
  RobertaGen.decode_()
  RobertaGen.decoder_forward()
  RobertaGen.encoder_forward()
  RobertaGen.forward()
  RobertaGen.model_test()
  RobertaGen.predict()
  RobertaGen.training