I have a tensor such as t=[[0.2],[0.8],[0.4]] and want to map each value>0.8 to 1, and to 0 if it is below that threshold.
I tried this with tf.where, but it does not work:
tf.where(t>0.8,tf.ones(t),tf.zeros_like(t))
Code:
self.temp_sim = tf.where(self.predictions>0.8,tf.ones(self.predictions),tf.zeros_like(self.predictions))
Error:
File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1700, in ones
output = fill(shape, constant(one, dtype=dtype), name=name)
File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 3468, in fill
"Fill", dims=dims, value=value, name=name)
File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 609, in _apply_op_helper
param_name=input_name)
File "C:\Users\\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 60, in _SatisfiesTypeConstraint
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: Value passed to parameter 'dims' has DataType float32 not in list of allowed values: int32, int64
(Formatting the traceback as code did not work!)
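For reference, the traceback points at the `dims` argument of `fill`, which is consistent with `tf.ones` taking a shape (integers) as its first argument rather than a tensor. A minimal sketch of one likely fix, assuming TensorFlow 2.x with eager execution, is to use `tf.ones_like` instead. The extra value 0.95 is not in the original tensor; it is added here only so that one entry exceeds the threshold.

```python
import tensorflow as tf

# Tensor from the question, plus one extra value (0.95) added for
# illustration so that the example contains an entry above 0.8.
t = tf.constant([[0.2], [0.8], [0.4], [0.95]], dtype=tf.float32)

# tf.ones(shape) expects integer dimensions, so passing the float tensor
# itself is what appears to trigger the 'dims' TypeError.
# tf.ones_like(t) instead builds a tensor of ones with t's shape and dtype.
thresholded = tf.where(t > 0.8, tf.ones_like(t), tf.zeros_like(t))

# Alternative with the same result: cast the boolean mask to float.
thresholded_cast = tf.cast(t > 0.8, tf.float32)

print(thresholded.numpy())
print(thresholded_cast.numpy())
```

Because the comparison is strict, the entry equal to 0.8 maps to 0. Under TensorFlow 1.x graph mode the same two expressions should apply, but the results would have to be evaluated in a session rather than with `.numpy()`.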