我已经从https://github.com/XingxingZhang/dress下载了模型。
每个文件夹都有两个文件“ 16.state.t7”(〜5MB)和“ 16.t7”(〜100MB)。现在,我想在pytorch甚至是Torch中加载模型并开始推理。
当我使用以下方法在pytorch中加载“ 16.state.t7”模型时:-
from torch.utils.serialization import load_lua
model = load_lua('16.state.t7')
然后执行:-
type(model)
我收到“ torch.utils.serialization.read_lua_file.hashable_uniq_dict”
类似地,当我使用以下命令加载“ 16.t7”模型时:-
model = load_lua('16.t7')
然后执行:-
type(model)
它给出了“火炬张量”
我也是
print(load_lua('16.state.t7').keys())
我得到:-
['sgdParam', 'optimMethod', 'dst_vocab', 'validout', 'nhid', 'patience', lmPath', 'sariRevWeight', 'deltaSamplePos', 'initHidVal', 'testout', 'save', 'lrDiv', 'fineTuneFactor', 'useGPU', 'batchSize', 'src_vocab', 'gradClip', 'freqCut', 'savePerEpoch', 'ignoreCase', 'valid', 'lr', 'lmWeight', 'minLR', 'test', 'nivocab', 'learnZ', 'sampleStart', 'rfEpoch', 'novocab', 'validBatchSize', 'nin', 'power', 'sariWeight', 'dropout', 'normalizeUNK', 'encdecPath', 'nneg', 'train', 'initRange', 'recDropout', 'nlayers', 'saveBeforeLrDiv', 'lnZ', 'seqLen', 'curLR', 'minImprovement', 'wordEmbedding', 'seed', 'maxEpoch', 'attention', 'simPath', 'simWeight', 'model', 'embedOption']