如何使用布尔掩码定义TFLearn自定义目标函数错误?

时间:2018-11-09 01:35:42

标签: python tensorflow neural-network tflearn

我正在尝试定义自定义损失函数

def my_objective(y_pred, y_true):
  pred_slice = tf.slice(y_pred, [0,0], [-1,1])
  true_slice = tf.slice(y_true, [0,0], [-1,1])
  mask = np.array(d['mask_'+str(0)], dtype=bool)
  masked_pred = tf.boolean_mask(pred_slice, mask)
  masked_true = tf.boolean_mask(true_slice, mask)
  return tf.reduce_mean(tf.square(masked_pred - masked_true))

...
net = tflearn.regression(net, optimizer='adam', loss=my_objective)
model = tflearn.DNN(net)
model.fit(Xtrain, ytrain)

但是尝试训练时收到以下错误:

  

---------------------------------运行ID:H57ZXZ日志目录:/ tmp / tflearn_logs /      

---------------------------------训练样本:13136验证样本:0

     

-------------------------------------------------- ---------------------------- InvalidArgumentError追溯(最近的调用   持续)   /opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_do_call(self,fn,* args)1333中尝试:   -> 1334返回fn(* args)1335,但error.OpError为e:

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_run_fn中(feed_dict,fetch_list,target_list,选项,run_metadata)   (1318)   -> 1319选项,feed_dict,fetch_list,target_list,run_metadata)1320

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_call_tf_sessionrun(自身,选项,feed_dict,fetch_list,   target_list,run_metadata)1406 self._session,选项,   feed_dict,fetch_list,target_list,   -> 1407 run_metadata)1408

     

InvalidArgumentError:indexes [0] = 66不在[0,64)中[[{{node   boolean_mask_1 / GatherV2}} = GatherV2 [Taxis = DT_INT32,   Tindices = DT_INT64,Tparams = DT_FLOAT,   _device =“ / job:localhost / replica:0 / task:0 / device:CPU:0”](boolean_mask_1 / Reshape,   boolean_mask / Squeeze,boolean_mask / concat / axis)]]

     

在处理上述异常期间,发生了另一个异常:

     

InvalidArgumentError跟踪(最近的调用)   最后)         1个模型= tflearn.DNN(net)   ----> 2 model.fit(Xtrain,ytrain)

     

/opt/conda/lib/python3.6/site-packages/tflearn/models/dnn.py在   fit(self,X_inputs,Y_targets,n_epoch,validation_set,show_metric,   batch_size,随机播放,snapshot_epoch,snapshot_step,excl_trainops,   validate_batch_size,run_id,回调)       214 excl_trainops = excl_trainops,       215 run_id = run_id,   -> 216个callbacks = callbacks)       217       218 def fit_batch(self,X_inputs,Y_targets):

     

/opt/conda/lib/python3.6/site-packages/tflearn/helpers/trainer.py在   适合(自己,feed_dicts,n_epoch,val_feed_dicts,show_metric,   snapshot_step,snapshot_epoch,shuffle_all,dprep_dict,daug_dict,   excl_trainops,run_id,回调)       337(bool(self.best_checkpoint_path)| snapshot_epoch),       第338章   -> 339 show_metric)       340       341#更新训练状态

     

/opt/conda/lib/python3.6/site-packages/tflearn/helpers/trainer.py在   _train(自我,training_step,snapshot_epoch,snapshot_step,show_metric)       816 tflearn.is_training(真实,会话= self.session)       817_,train_summ_str = self.session.run([self.train,self.summ_op],   -> 818 feed_batch)       819       820#从摘要字符串中检索损失值

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在运行中(自我,获取,feed_dict,选项,run_metadata)       927尝试:       (928)第928章   -> 929 run_metadata_ptr)       930如果run_metadata:       931 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_run(自身,句柄,访存,feed_dict,选项,run_metadata)中
  1150如果final_fetches或final_targets或(句柄和   feed_dict_tensor):1151个结果= self._do_run(句柄,   final_targets,final_fetches,   -> 1152 feed_dict_tensor,选项,run_metadata)1153其他:1154结果= []

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_do_run(自身,句柄,target_list,fetch_list,feed_dict,选项,   run_metadata)1326如果句柄为None:1327返回   self._do_call(_run_fn,提要,提取,目标,选项,   -> 1328 run_metadata)1329否则:1330返回self._do_call(_prun_fn,句柄,提要,获取)

     

/opt/conda/lib/python3.6/site-packages/tensorflow/python/client/session.py   在_do_call(self,fn,* args)中1346通过1347
  消息= error_interpolation.interpolate(消息,self._graph)   -> 1348提高类型(e)(node_def,op,消息)1349 1350 def _extend_graph(自身):

     

InvalidArgumentError:索引[0] = 66不在[0,64)中[[node   boolean_mask_1 / GatherV2(定义为:6)   = GatherV2 [出租汽车= DT_INT32,丁迪克斯= DT_INT64,Tparams = DT_FLOAT,_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”]](boolean_mask_1 / Reshape,   boolean_mask / Squeeze,boolean_mask / concat / axis)]]

     

由op'boolean_mask_1 / GatherV2'引起,在以下位置定义:File   _run_module_as_main中的“ /opt/conda/lib/python3.6/runpy.py”,第193行       “ 主要”,mod_spec)文件“ /opt/conda/lib/python3.6/runpy.py”,第85行,_run_code       exec(代码,run_globals)文件“ /opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py”,行   16,在       app.launch_new_instance()文件“ /opt/conda/lib/python3.6/site-packages/traitlets/config/application.py”,   658行,在launch_instance中       app.start()文件“ /opt/conda/lib/python3.6/site-packages/ipykernel/kernelapp.py”,行   505,开始时       self.io_loop.start()文件“ /opt/conda/lib/python3.6/site-packages/tornado/platform/asyncio.py”,   第132行,开始时       self.asyncio_loop.run_forever()文件“ /opt/conda/lib/python3.6/asyncio/base_events.py”,第422行,在   run_forever       self._run_once()文件“ /opt/conda/lib/python3.6/asyncio/base_events.py”,第1434行,在   _run_once       handle._run()在_run中的文件“ /opt/conda/lib/python3.6/asyncio/events.py”,第145行       self._callback(* self._args)文件“ /opt/conda/lib/python3.6/site-packages/tornado/ioloop.py”,第758行,   在_run_callback中       ret = callback()文件“ /opt/conda/lib/python3.6/site-packages/tornado/stack_context.py”,   第300行,在null_wrapper中       返回fn(* args,** kwargs)文件“ /opt/conda/lib/python3.6/site-packages/tornado/gen.py”,行1233,在   内       self.run()文件“ /opt/conda/lib/python3.6/site-packages/tornado/gen.py”,行1147,在   跑       产生= self.gen.send(value)文件“ /opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py”,行   357,在process_one中       产生gen.maybe_future(dispatch(* args))文件“ /opt/conda/lib/python3.6/site-packages/tornado/gen.py”,行326,在   包装纸       yield =下一个(结果)文件“ /opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py”,行   267,在dispatch_shell中       产生gen.maybe_future(handler(stream,idents,msg))文件“ /opt/conda/lib/python3.6/site-packages/tornado/gen.py”,第326行,在   包装纸       yield =下一个(结果)文件“ /opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py”,行   534,在execute_request中       user_expressions,allow_stdin,文件“ /opt/conda/lib/python3.6/site-packages/tornado/gen.py”,第326行,在   包装纸       yield =下一个(结果)文件“ /opt/conda/lib/python3.6/site-packages/ipykernel/ipkernel.py”,行   294,在do_execute中       res = shell.run_cell(代码,store_history = store_history,silent =静音)文件   “ /opt/conda/lib/python3.6/site-packages/ipykernel/zmqshell.py”,行   536,在run_cell中       返回super(ZMQInteractiveShell,self).run_cell(* args,** kwargs)文件   “ /opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py”,   第2817行,在run_cell中       raw_cell,store_history,silent,shell_futures)文件“ /opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py”,   _run_cell中的第2843行       returnRunner(coro)文件“ /opt/conda/lib/python3.6/site-packages/IPython/core/async_helpers.py”,   _pseudo_sync_runner中的第67行       coro.send(None)文件“ /opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py”,   第3018行,在run_cell_async中       交互性=交互性,编译器=编译器,结果=结果)文件   “ /opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py”,   第3183行,在run_ast_nodes中       如果(从self.run_code(代码,结果)产生):文件“ /opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py”,   第3265行,在run_code中       exec(code_obj,self.user_global_ns,self.user_ns)文件“”,第7行,在       净= tflearn.regression(净,优化程序='adam',损失= my_objective)文件   “ /opt/conda/lib/python3.6/site-packages/tflearn/layers/estimator.py”,   第178行,回归       损失=损失(传入,占位符)文件“”,my_objective中的第6行       masked_true = tf.boolean_mask(true_slice,mask)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py”,   第1204行,在boolean_mask中       返回_apply_mask_1d(张量,掩码,轴)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py”,   _apply_mask_1d中的第1174行       返回聚集(reshaped_tensor,索引,轴=轴)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py”,   2675行,聚集       返回gen_array_ops.gather_v2(参数,索引,轴,名称=名称)文件   “ /opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py”,   第3332行,在collect_v2中       “ GatherV2”,参数=参数,索引=索引,轴=轴,名称=名称)文件   “ /opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py”,   _apply_op_helper中的第787行       op_def = op_def)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py”,   第488行,在new_func中       返回func(* args,** kwargs)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py”,   第3274行,在create_op中       op_def = op_def)文件“ /opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py”,   第1770行,在 init 中       self._traceback = tf_stack.extract_stack()

     

InvalidArgumentError(请参阅上面的回溯):indexs [0] = 66不是   在[0,64)中[[node boolean_mask_1 / GatherV2(在   :6)= GatherV2 [Taxis = DT_INT32,   Tindices = DT_INT64,Tparams = DT_FLOAT,   _device =“ / job:localhost / replica:0 / task:0 / device:CPU:0”](boolean_mask_1 / Reshape,   boolean_mask / Squeeze,boolean_mask / concat / axis)]]

0 个答案:

没有答案