如何在Python中加载火炬预训练模型?

时间:2017-09-16 18:58:51

标签: python numpy deep-learning torch pytorch

我有一个Alexch的二值化预训练模型的火炬.t7文件。我需要将此.t7文件转换为npz / h5文件,以便我可以轻松导入。

这是allenAI的alexnet_XNOR。

这是用于读取文件的代码:https://github.com/bshillingford/python-torchfile/blob/master/torchfile.py

import torchfile as T
o = T.load('alexnet_XNOR.t7')

这是出现的错误:

File "<ipython-input-5-7ef550de1de7>", line 1, in <module>
runfile('C:/Users/yash/Desktop/importer.py', wdir='C:/Users/yash/Desktop')

File "C:\Program Files\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 866, in runfile
execfile(filename, namespace)

File "C:\Program Files\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)

File "C:/Users/yash/Desktop/importer.py", line 10, in <module>
o = T.load('alexnet_XNOR.t7')

File "C:\Users\yash\Desktop\torchreader.py", line 430, in load
return reader.read_obj()

File "C:\Users\yash\Desktop\torchreader.py", line 377, in read_obj
obj._obj = self.read_obj()

File "C:\Users\yash\Desktop\torchreader.py", line 394, in read_obj
obj[k] = v

TypeError: unhashable type: 'numpy.ndarray'

我的尝试: 我修改了火炬手并运行了主代码(添加了打印语句):

for _ in range(size):
            k = self.read_obj()
            v = self.read_obj()
            print(obj.shape)
            print(v.shape)
            obj[k] = v
            if self.use_list_heuristic:
                if not isinstance(k, int) or k <= 0:
                    keys_natural = False
                elif isinstance(k, int):
                    key_sum += k

对于obj.shape说没有,对于v.shape,它说:

AttributeError:'bytes'对象没有属性'shape'

print(obj)
print("===========")
print(v)

当我使用上面的代码片段打印obj和v时,我得到的就是我最初得到的错误:

{}
=================
b'torch.CudaTensor'
{b'_type': b'torch.CudaTensor'}
=================
[]
{b'output': array([], dtype=float32), b'_type': b'torch.CudaTensor'}
=================
b'gradInput'
{None: b'gradInput', b'output': array([], dtype=float32), b'_type': b'torch.CudaTensor'}
=================
[ 0.]

有人可以解释为什么会发生这种错误吗?我无法理解代码库本身,因为我之前没有使用过火炬。

谢谢!

0 个答案:

没有答案