Machine translation: wrong results after loading the .pt file
I have a machine translation model that is already trained, and I exported it to a .pt file. However, when I use it, the translation results differ from the results I got right after training, before the application was closed. The calling code is as follows:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import time
import math
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with open('./cmn.txt', 'r', encoding='utf-8') as f:
    data = f.read()
data = data.strip()
data = data.split('\n')
print('Number of samples:\n', len(data))
print('\nSample examples:')
# Split each line into its English and Chinese parts
en_data = [line.split('\t')[0] for line in data]
ch_data = [line.split('\t')[1] for line in data]
print('English data:\n', en_data[:10])
print('\nChinese data:\n', ch_data[:10])
# Tokenize at character level and append the end-of-sequence token <eos>
en_token_list = [[char for char in line] + ['<eos>'] for line in en_data]
ch_token_list = [[char for char in line] + ['<eos>'] for line in ch_data]
print('English data:\n', en_token_list[:2])
print('\nChinese data:\n', ch_token_list[:2])
# Basic dictionary of special tokens (padding, unknown, begin/end of sequence)
basic_dict = {'<pad>': 0, '<unk>': 1, '<bos>': 2, '<eos>': 3}
# Build the English dictionary
en_vocab = set(''.join(en_data))
en2id = {char: i+len(basic_dict) for i, char in enumerate(en_vocab)}
en2id.update(basic_dict)
id2en = {v: k for k, v in en2id.items()}
# Build the Chinese dictionary
ch_vocab = set(''.join(ch_data))
ch2id = {char: i+len(basic_dict) for i, char in enumerate(ch_vocab)}
ch2id.update(basic_dict)
id2ch = {v: k for k, v in ch2id.items()}
# Map the data to indices using the dictionaries
en_num_data = [[en2id[en] for en in line] for line in en_token_list]
ch_num_data = [[ch2id[ch] for ch in line] for line in ch_token_list]
print('char:', en_data[1])
print('index:', en_num_data[1])
print('char:', ch_data[1])
print('index:', ch_num_data[1])
def translate(
        model,
        sample,
        idx2token=None
):
    model.predict = True
    model.eval()
    # shape = [seq_len, 1]
    input_batch = sample["src"]
    # list of source lengths
    input_len = sample["src_len"]
    output_tokens = model(input_batch, input_len)
    output_tokens = [idx2token[t] for t in output_tokens]
    return "".join(output_tokens)
INPUT_DIM = len(en2id)
OUTPUT_DIM = len(ch2id)
# Hyperparameters
BATCH_SIZE = 32
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
LEARNING_RATE = 1e-4
N_EPOCHS = 100
CLIP = 1
bidirectional = True
attn_method = "general"
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM,
              N_LAYERS, ENC_DROPOUT, bidirectional)
dec = AttnDecoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS,
                  DEC_DROPOUT, bidirectional, attn_method)
model = Seq2Seq(enc, dec, device, basic_dict=basic_dict).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Dataset
train_set = TranslationDataset(en_num_data, ch_num_data)
train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, collate_fn=padding_batch)
best_valid_loss = float('inf')
# Load the best weights
model.load_state_dict(torch.load("en2ch-attn-model.pt"))
while True:
    in_data = input("English: ")
    input_data = []
    input_data.append(in_data)
    en_token_list = [[char for char in line] + ['<eos>'] for line in input_data]
    en_num_data = [[en2id[en] for en in line] for line in en_token_list]
    en_tokens = list(filter(lambda x: x != 0, en_num_data[0]))  # filter out id 0 (padding)
    sentence = [id2en[t] for t in en_tokens]
    test_sample = {}
    test_sample["src"] = torch.tensor(
        en_tokens, dtype=torch.long, device=device).reshape(-1, 1)
    test_sample["src_len"] = [len(en_tokens)]
    print("[Machine translation]")
    print(translate(model, test_sample, id2ch), end="\n\n")
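For reference, below is a minimal sketch of how the weights and the token dictionaries could be saved and reloaded together, assuming the same Seq2Seq/Encoder/AttnDecoder classes and the en2id/ch2id objects from the code above. The helper names save_checkpoint/load_checkpoint and the file name "en2ch-vocab.pkl" are made up for illustration. One detail worth noting: in the code above, en2id and ch2id are rebuilt from a Python set on every run, and set iteration order is not guaranteed to repeat across interpreter runs, so persisting the dictionaries keeps the char-to-index mapping at inference identical to the one used during training.

# --- Sketch (not from the original post): persist the vocabulary with the weights ---
import pickle

def save_checkpoint(model, en2id, ch2id,
                    weight_path="en2ch-attn-model.pt",
                    vocab_path="en2ch-vocab.pkl"):
    # Save the trained weights and the index dictionaries side by side.
    torch.save(model.state_dict(), weight_path)
    with open(vocab_path, "wb") as f:
        pickle.dump({"en2id": en2id, "ch2id": ch2id}, f)

def load_checkpoint(model, weight_path="en2ch-attn-model.pt",
                    vocab_path="en2ch-vocab.pkl"):
    # Reload both, mapping the weights onto the current device.
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.eval()  # disable dropout so decoding is deterministic
    with open(vocab_path, "rb") as f:
        vocabs = pickle.load(f)
    return vocabs["en2id"], vocabs["ch2id"]

With this pattern, id2en and id2ch at inference time are derived from the loaded dictionaries rather than from a freshly built set.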