I'm trying to use train_image_classifier.py from slim nets with nasnet_mobile to train NASNet-A_Mobile_224 for two-class classification, but I get this error:
TypeError: separable_convolution2d() got an unexpected keyword argument 'data_format'
I suspect the new NASNet requires TensorFlow 1.4. Can someone confirm this? I am using TensorFlow 1.3.
The full traceback is below:
Traceback (most recent call last):
  File "train_image_classifier.py", line 574, in <module>
    tf.app.run()
  File "/home/sami/virenv/tensorflow_vanilla/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "train_image_classifier.py", line 474, in main
    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
  File "/home/sami/projects/Tools/models/research/slim/deployment/model_deploy.py", line 193, in create_clones
    outputs = model_fn(*args, **kwargs)
  File "train_image_classifier.py", line 457, in clone_fn
    logits, end_points = network_fn(images)
  File "/home/sami/projects/Tools/models/research/slim/nets/nets_factory.py", line 135, in network_fn
    return func(images, num_classes, is_training=is_training, **kwargs)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet.py", line 371, in build_nasnet_mobile
    final_endpoint=final_endpoint)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet.py", line 450, in _build_nasnet_base
    net, cell_outputs = stem()
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet.py", line 445, in <lambda>
    stem = lambda: _imagenet_stem(images, hparams, stem_cell)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet.py", line 264, in _imagenet_stem
    cell_num=cell_num)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet_utils.py", line 326, in __call__
    stride, original_input_left)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet_utils.py", line 352, in _apply_conv_operation
    net = _stacked_separable_conv(net, stride, operation, filter_size)
  File "/home/sami/projects/Tools/models/research/slim/nets/nasnet/nasnet_utils.py", line 183, in _stacked_separable_conv
    stride=stride)
  File "/home/sami/virenv/tensorflow_vanilla/local/lib/python2.7/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
TypeError: separable_convolution2d() got an unexpected keyword argument 'data_format'
Answer:
Yes, it requires TensorFlow 1.4.0.
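To fail early with a clear message instead of the TypeError deep inside nasnet_utils.py, a training script could compare the installed version against 1.4.0 before building the network. This is a minimal sketch, assuming (per the answer above) that 1.4.0 is the minimum version the slim NASNet code needs; the names MIN_TF_FOR_NASNET, parse_version and supports_nasnet are hypothetical helpers, not part of TensorFlow or slim. It is plain Python, so it runs under both Python 2.7 and 3.

```python
# Minimal sketch: decide whether a TensorFlow version string is new enough
# for slim's nasnet code. Assumption (from the answer): 1.4.0 is the minimum,
# because TF 1.3's separable_convolution2d() has no data_format argument.
import re

MIN_TF_FOR_NASNET = (1, 4, 0)  # hypothetical constant, not a TF/slim name


def parse_version(version):
    """Turn strings like '1.3.0' or '1.4.0-rc1' into a 3-tuple of ints."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits, so '0-rc1' becomes 0.
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as '1.4' to (1, 4, 0).
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def supports_nasnet(version, minimum=MIN_TF_FOR_NASNET):
    """Return True if the version string is >= the assumed minimum."""
    return parse_version(version) >= minimum
```

For example, calling supports_nasnet(tf.__version__) near the top of main() in train_image_classifier.py would return False on the 1.3 install from the question, so the script could exit with an "upgrade to TensorFlow >= 1.4.0" message. Note that release candidates such as 1.4.0-rc1 are treated as 1.4.0 by this simple parser.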