I am trying to test a single image with the TensorFlow CIFAR-10 tutorial, using the code provided by bigdata2 for the training step as described in the linked post [1]. I made the following changes to CIFAR10_train.py:

def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    with tf.device('/cpu:0'):
      images, labels = cifar10.distorted_inputs()
    is_training = tf.placeholder(dtype=bool,shape=(),name='is_training')
    imgs = tf.placeholder(tf.float32, (1, 32, 32, 3), name='imgs')
    images = tf.cond(is_training, lambda:images, lambda:imgs)
    logits = cifar10.inference(images)

The other change I made is:

import numpy as np
tmp_img = np.ndarray(shape=(1,32,32,3), dtype=float)
with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op, feed_dict={is_training: True, imgs: tmp_img})

I get this error:

runfile('C:/Users/User/new_train.py', wdir='C:/Users/User')
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
Traceback (most recent call last):
  File "", line 1, in <module>
    runfile('C:/Users/User/new_train.py', wdir='C:/Users/User')
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "C:/Users/User/new_train.py", line 135, in <module>
    tf.app.run()
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "C:/Users/User/new_train.py", line 131, in main
    train()
  File "C:/Users/User/new_train.py", line 77, in train
    logits = cifar10.inference(images)
  File "C:\Users\User\cifar10.py", line 246, in inference
    stddev=0.04, wd=0.004)
  File "C:\Users\User\cifar10.py", line 135, in _variable_with_weight_decay
    tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  File "C:\Users\User\cifar10.py", line 111, in _variable_on_cpu
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1203, in get_variable
    constraint=constraint)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1092, in get_variable
    constraint=constraint)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 425, in get_variable
    constraint=constraint)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 394, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "C:\Users\User\AppData\Local\conda\conda\envs\tensorflow_windows\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 763, in _get_single_variable
    "but instead was %s." % (name, shape))
ValueError: Shape of a new variable (local3/weights) must be fully defined, but instead was (?, 384).
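
One possible reading of the `(?, 384)` shape, sketched in plain Python rather than TensorFlow: when the two branches of `tf.cond` return tensors whose static shapes disagree, the result keeps only the dimensions both branches share, and the rest become unknown. The helper names `merge_branch_shapes` and `flattened_size` are hypothetical, and the `(128, 24, 24, 3)` training batch shape is an assumption based on the stock tutorial defaults, not something shown in the traceback.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def merge_branch_shapes(a: Shape, b: Shape) -> Shape:
    """Keep a dimension only where both branches agree; otherwise mark it unknown (None)."""
    if len(a) != len(b):
        raise ValueError("this sketch only handles branches of equal rank")
    return tuple(x if x == y else None for x, y in zip(a, b))


def flattened_size(shape: Shape) -> Optional[int]:
    """Product of the non-batch dimensions, or None if any of them is unknown."""
    size = 1
    for dim in shape[1:]:
        if dim is None:
            return None
        size *= dim
    return size


# Assumed static shape of the batch from cifar10.distorted_inputs().
train_batch = (128, 24, 24, 3)
# Static shape of the imgs placeholder in the modified train().
single_image = (1, 32, 32, 3)

merged = merge_branch_shapes(train_batch, single_image)
print(merged)                  # (None, None, None, 3)
print(flattened_size(merged))  # None

# If both branches had the same static shape, nothing would be lost.
same = merge_branch_shapes(train_batch, train_batch)
print(same)                    # (128, 24, 24, 3)
print(flattened_size(same))    # 1728
```

If this reading is right, the first dimension of `local3/weights` is computed from a flattened feature map whose height and width are unknown, which would explain the `?`. Under that assumption, giving the placeholder branch the same static shape as the training branch is one direction to try.
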

Can anyone suggest how to fix this error?