我正在尝试从PyTorch导入一些VGG19自定义权重到Tensorflow。
我一直在使用TorchVision VGG模型架构,但令我惊讶的是,在将所有权重导入Tensorflow之后,在全连接层中没有得到相同的结果,但是在CNN上却得到了相同的结果。
一段时间后,我意识到在PyTorch的文档中使用的是torch.Tensor.view
,而我使用的是tf.layers.Flatten()
因此,我相信从here可以理解,连续数组存在问题。然后,我尝试使用torch.flatten()
而不是torch.Tensor.view
,这在两个框架中都得到了相同的结果。
我的问题是,由于我的模型是用view
而不是flatten
训练的,是否有任何方法可以在Tensorflow中重现view
的相同效果?实际上view
在做什么而其他方法却没有?
谢谢!