mwptoolkit.model
- mwptoolkit.model.Seq2Seq
- mwptoolkit.model.Seq2Seq.dns
DNS
DNS.calculate_loss()
DNS.convert_idx2symbol()
DNS.convert_in_idx_2_out_idx()
DNS.convert_out_idx_2_in_idx()
DNS.decode()
DNS.decoder_forward()
DNS.encoder_forward()
DNS.filter_END()
DNS.filter_op()
DNS.forward()
DNS.init_decoder_inputs()
DNS.model_test()
DNS.predict()
DNS.rule1_filter()
DNS.rule2_filter()
DNS.rule3_filter()
DNS.rule4_filter()
DNS.rule5_filter()
DNS.rule_filter_()
DNS.training
- mwptoolkit.model.Seq2Seq.ept
- mwptoolkit.model.Seq2Seq.groupatt
GroupATT
GroupATT.calculate_loss()
GroupATT.convert_idx2symbol()
GroupATT.convert_in_idx_2_out_idx()
GroupATT.convert_out_idx_2_in_idx()
GroupATT.decode()
GroupATT.decoder_forward()
GroupATT.encoder_forward()
GroupATT.forward()
GroupATT.init_decoder_inputs()
GroupATT.model_test()
GroupATT.predict()
GroupATT.process_gap_encoder_decoder()
GroupATT.training
- mwptoolkit.model.Seq2Seq.mathen
- mwptoolkit.model.Seq2Seq.rnnencdec
RNNEncDec
RNNEncDec.calculate_loss()
RNNEncDec.convert_idx2symbol()
RNNEncDec.convert_in_idx_2_out_idx()
RNNEncDec.convert_out_idx_2_in_idx()
RNNEncDec.decode()
RNNEncDec.decoder_forward()
RNNEncDec.encoder_forward()
RNNEncDec.forward()
RNNEncDec.init_decoder_inputs()
RNNEncDec.model_test()
RNNEncDec.predict()
RNNEncDec.training
- mwptoolkit.model.Seq2Seq.rnnvae
- mwptoolkit.model.Seq2Seq.saligned
- mwptoolkit.model.Seq2Seq.transformer
Transformer
Transformer.calculate_loss()
Transformer.convert_idx2symbol()
Transformer.convert_in_idx_2_out_idx()
Transformer.convert_out_idx_2_in_idx()
Transformer.decode()
Transformer.decoder_forward()
Transformer.encoder_forward()
Transformer.forward()
Transformer.init_decoder_inputs()
Transformer.model_test()
Transformer.predict()
Transformer.training
- mwptoolkit.model.Seq2Seq.dns
- mwptoolkit.model.Seq2Tree
- mwptoolkit.model.Seq2Tree.berttd
- mwptoolkit.model.Seq2Tree.gts
- mwptoolkit.model.Seq2Tree.mwpbert
- mwptoolkit.model.Seq2Tree.sausolver
SAUSolver
SAUSolver.calculate_loss()
SAUSolver.convert_idx2symbol()
SAUSolver.decoder_forward()
SAUSolver.encoder_forward()
SAUSolver.evaluate_tree()
SAUSolver.forward()
SAUSolver.generate_tree_input()
SAUSolver.get_all_number_encoder_outputs()
SAUSolver.model_test()
SAUSolver.mse_loss()
SAUSolver.predict()
SAUSolver.train_tree()
SAUSolver.training
- mwptoolkit.model.Seq2Tree.treelstm
- mwptoolkit.model.Seq2Tree.trnn
TRNN
TRNN.ans_module_calculate_loss()
TRNN.ans_module_forward()
TRNN.calculate_loss()
TRNN.convert_idx2symbol()
TRNN.convert_in_idx_2_temp_idx()
TRNN.convert_temp_idx2symbol()
TRNN.convert_temp_idx_2_in_idx()
TRNN.forward()
TRNN.init_seq2seq_decoder_inputs()
TRNN.mask2num()
TRNN.model_test()
TRNN.predict()
TRNN.seq2seq_calculate_loss()
TRNN.seq2seq_decoder_forward()
TRNN.seq2seq_encoder_forward()
TRNN.seq2seq_forward()
TRNN.seq2seq_generate_t()
TRNN.seq2seq_generate_without_t()
TRNN.symbol2idx()
TRNN.template2tree()
TRNN.training
TRNN.tree2equation()
- mwptoolkit.model.Seq2Tree.tsn
TSN
TSN.build_graph()
TSN.convert_idx2symbol()
TSN.forward()
TSN.generate_tree_input()
TSN.get_all_number_encoder_outputs()
TSN.get_soft_target()
TSN.init_encoder_mask()
TSN.init_soft_target()
TSN.model_test()
TSN.predict()
TSN.student_calculate_loss()
TSN.student_net_1_decoder_forward()
TSN.student_net_2_decoder_forward()
TSN.student_net_decoder_forward()
TSN.student_net_encoder_forward()
TSN.student_net_forward()
TSN.student_test()
TSN.teacher_calculate_loss()
TSN.teacher_net_decoder_forward()
TSN.teacher_net_encoder_forward()
TSN.teacher_net_forward()
TSN.teacher_test()
TSN.training
cosine_loss()
cosine_sim()
soft_cross_entropy_loss()
soft_target_loss()
- mwptoolkit.model.Graph2Tree
- mwptoolkit.model.Graph2Tree.graph2tree
Graph2Tree
Graph2Tree.build_graph()
Graph2Tree.calculate_loss()
Graph2Tree.convert_idx2symbol()
Graph2Tree.decoder_forward()
Graph2Tree.encoder_forward()
Graph2Tree.forward()
Graph2Tree.generate_tree_input()
Graph2Tree.get_all_number_encoder_outputs()
Graph2Tree.model_test()
Graph2Tree.predict()
Graph2Tree.training
- mwptoolkit.model.Graph2Tree.multiencdec
MultiEncDec
MultiEncDec.attn_decoder_forward()
MultiEncDec.calculate_loss()
MultiEncDec.convert_idx2symbol1()
MultiEncDec.convert_idx2symbol2()
MultiEncDec.decoder_forward()
MultiEncDec.encoder_forward()
MultiEncDec.forward()
MultiEncDec.generate_decoder_input()
MultiEncDec.generate_tree_input()
MultiEncDec.get_all_number_encoder_outputs()
MultiEncDec.model_test()
MultiEncDec.predict()
MultiEncDec.training
MultiEncDec.tree_decoder_forward()
- mwptoolkit.model.Graph2Tree.graph2tree
- mwptoolkit.model.PreTrain