Machine translation: wrong results after loading a .pt file

802 views · 0 comments · 2021-09-06

I have a machine translation model that has already been trained, and I exported its weights to a .pt file. However, when I load the file and run inference, the translation results differ from what I got right after training, before closing the application. The calling code is as follows:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import time
import math
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('./cmn.txt', 'r', encoding='utf-8') as f:
    data = f.read()
data = data.strip()
data = data.split('\n')
print('Number of samples:\n', len(data))
print('\nSample preview:')
# Split each line into its English and Chinese parts
en_data = [line.split('\t')[0] for line in data]
ch_data = [line.split('\t')[1] for line in data]
print('English data:\n', en_data[:10])
print('\nChinese data:\n', ch_data[:10])
# Tokenize at the character level and append an end-of-sequence token
en_token_list = [[char for char in line] + ['<eos>'] for line in en_data]
ch_token_list = [[char for char in line] + ['<eos>'] for line in ch_data]
print('English tokens:\n', en_token_list[:2])
print('\nChinese tokens:\n', ch_token_list[:2])
# Special tokens: padding, unknown, begin- and end-of-sequence
basic_dict = {'<pad>': 0, '<unk>': 1, '<bos>': 2, '<eos>': 3}
# Build the English vocabulary and char<->id mappings
en_vocab = set(''.join(en_data))
en2id = {char: i+len(basic_dict) for i, char in enumerate(en_vocab)}
en2id.update(basic_dict)
id2en = {v: k for k, v in en2id.items()}

# Build the Chinese vocabulary and char<->id mappings
ch_vocab = set(''.join(ch_data))
ch2id = {char: i+len(basic_dict) for i, char in enumerate(ch_vocab)}
ch2id.update(basic_dict)
id2ch = {v: k for k, v in ch2id.items()}
# Map the token sequences to id sequences using the vocabularies
en_num_data = [[en2id[en] for en in line] for line in en_token_list]
ch_num_data = [[ch2id[ch] for ch in line] for line in ch_token_list]

print('char:', en_data[1])
print('index:', en_num_data[1])
print('char:', ch_data[1])
print('index:', ch_num_data[1])
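
# A possible culprit for the mismatch described above: set iteration
# order is not stable across Python processes (string hashing is
# randomized), so the en2id/ch2id built here can differ from the
# mapping in effect when the weights were trained. A minimal sketch,
# assuming the vocabularies are saved once at training time and
# reloaded here (save_vocab/load_vocab are hypothetical helpers):
import json

def save_vocab(vocab, path):
    # call once in the training script, right after building the vocab
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(vocab, f, ensure_ascii=False)

def load_vocab(path):
    # call here instead of rebuilding the vocab from a set
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

# training script:  save_vocab(en2id, 'en2id.json'); save_vocab(ch2id, 'ch2id.json')
# this script:      en2id = load_vocab('en2id.json'); ch2id = load_vocab('ch2id.json')
#                   id2en = {v: k for k, v in en2id.items()}  # likewise id2ch
# Alternatively, building from sorted(set(...)) makes the rebuilt
# mapping deterministic for identical data.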

def translate(
        model,
        sample,
        idx2token=None
        ):
    # prediction mode (model.predict is a flag on the custom Seq2Seq
    # class; in this mode the forward pass is assumed to return a
    # sequence of output token ids)
    model.predict = True
    model.eval()

    # shape = [seq_len, 1]
    input_batch = sample["src"]
    # list of sequence lengths
    input_len = sample["src_len"]

    # no gradients needed during inference
    with torch.no_grad():
        output_tokens = model(input_batch, input_len)
    output_tokens = [idx2token[t] for t in output_tokens]

    return "".join(output_tokens)

INPUT_DIM = len(en2id)
OUTPUT_DIM = len(ch2id)
# Hyperparameters
BATCH_SIZE = 32
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
LEARNING_RATE = 1e-4
N_EPOCHS = 100
CLIP = 1

bidirectional = True
attn_method = "general"
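# Encoder, AttnDecoder and Seq2Seq are the custom model classes from the
# training script; their definitions are not included in this post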
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM,
              N_LAYERS, ENC_DROPOUT, bidirectional)
dec = AttnDecoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS,
                  DEC_DROPOUT, bidirectional, attn_method)
model = Seq2Seq(enc, dec, device, basic_dict=basic_dict).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Dataset (TranslationDataset and padding_batch also come from the
# training script; like the optimizer above, they are not used for inference)
train_set = TranslationDataset(en_num_data, ch_num_data)
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, collate_fn=padding_batch)
best_valid_loss = float('inf')

# Load the best weights (map_location keeps the load working on CPU-only machines)
model.load_state_dict(torch.load("en2ch-attn-model.pt", map_location=device))
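
# For reference, load_state_dict assumes the .pt file holds a state_dict.
# A sketch of the matching training-time save (a hypothetical
# reconstruction; the otherwise unused best_valid_loss variable above
# suggests a save-on-best-validation-loss pattern):
#
#     if valid_loss < best_valid_loss:
#         best_valid_loss = valid_loss
#         torch.save(model.state_dict(), "en2ch-attn-model.pt")
#
# If the file had been written with torch.save(model, ...) instead,
# load_state_dict would raise an error rather than give wrong results.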

while True:
    in_data = input("English: ")
    input_data = [in_data]
    en_token_list = [[char for char in line] + ['<eos>'] for line in input_data]
    # map chars to ids, falling back to <unk> for characters never seen
    # in the training data (plain en2id[en] would raise KeyError)
    en_num_data = [[en2id.get(en, en2id['<unk>']) for en in line] for line in en_token_list]

    en_tokens = list(filter(lambda x: x != 0, en_num_data[0]))  # drop padding ids
    sentence = [id2en[t] for t in en_tokens]
    test_sample = {}
    test_sample["src"] = torch.tensor(
        en_tokens, dtype=torch.long, device=device).reshape(-1, 1)
    test_sample["src_len"] = [len(en_tokens)]
    print("[Machine translation]")
    print(translate(model, test_sample, id2ch), end="\n\n")
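
# Illustrative diagnostic for the vocabulary hypothesis (it would need
# to run before the interactive loop above): print a stable fingerprint
# of en2id in both the training script and this one; differing hashes
# mean the weights are being fed a different char->id mapping than they
# were trained with.
#
#     import hashlib, json
#     blob = json.dumps(en2id, sort_keys=True, ensure_ascii=False)
#     print("en2id fingerprint:", hashlib.md5(blob.encode('utf-8')).hexdigest())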
