我需要将填充零填充到从TFRecord文件读取的张量(代码片段中的'ARRAY')。因为我使用的训练模型要求它们的形状应该相同。但是,我的输入有不同的宽度和长度。因此,我试图计算代码片段中应该填充的零的数量('paddings = tf.Variable([[0,targetLength],[0,targetWidth]])'。 但是,tensorflow引发了InvalidArgumentError,并且我的代码中从未出现过值“arg0”。
变量ARRAY的示例如下:
[1,0,1,0]
[2,0,2,0]
我应该将其填充到
[1,0,1,0,0,0]
[2,0,2,0,0,0]
[0,0,0,0,0,0]
可以有很多大型阵列,所以我想在训练前填上它。 这是我的代码片段。
def my_input_fn(file_path, perform_shuffle=True, repeat_count=1):
global width, length # two int64 variables.
batchNum = 32
def parse_ARRAY(tfrecord):
features = tf.parse_single_example(
tfrecord,
# Defaults are not specified since both keys are required.
features={
'label': tf.FixedLenFeature([], tf.int64),
'length': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'ARRAY': tf.FixedLenFeature([], tf.string)
})
ARRAY = tf.decode_raw(features['ARRAY'], tf.int64)
label = tf.cast(features['label'], tf.int64)
ARRAYLength = tf.cast(features['length'], tf.int64)
ARRAYWidth = tf.cast(features['width'], tf.int64)
ARRAYshape = tf.stack([ARRAYLength, ARRAYWidth])
ARRAY = tf.reshape(ARRAY, ARRAYshape)
TFWidth = tf.convert_to_tensor(width, tf.int64)
TFLength = tf.convert_to_tensor(length, tf.int64)
targetWidth = tf.subtract(TFWidth, ARRAYWidth)
targetLength = tf.subtract(TFLength, ARRAYLength)
paddings = tf.Variable([[0, targetLength],[0, targetWidth]])
with tf.Session() as sess:
sess.run(paddings.initializer)
tf.pad(ARRAY, paddings, "CONSTANT")
return {"ARRAY":ARRAY}, label
dataset = tf.data.TFRecordDataset(file_path)
dataset = dataset.map(parse_ARRAY)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
dataset = dataset.repeat(repeat_count) # Repeats dataset this # time
dataset = dataset.batch(batchNum) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
def run_tfr(args):
global length, width
model = tf.estimator.Estimator(model_fn)
model.train(input_fn = lambda: my_input_fn(args[0]), steps=num_steps)
e = model.evaluate(input_fn = lambda: my_input_fn(args[0]+".tests"))
print("Testing Accuracy:", e['accuracy'])
if __name__ == "__main__":
width, length = load_data.loadInfo(sys.argv[1])
# the usage is 'python thisfile.py file.pkl' or
# 'python thisfile.py file.tfrecord'
if sys.argv[1].endswith(".pkl"):
# handle a file from cPickle.
elif sys.argv[1].endswith("tfrecord"):
run_tfr(sys.argv[1:])
这是tensorflow的输出。
Caused by op u'arg0', defined at:
File "DLCNN.py", line 230, in <module>
run_tfr(sys.argv[1:])
File "DLCNN.py", line 215, in run_tfr
model.train(input_fn = lambda: my_input_fn(args[0]), steps=num_steps)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 708, in _train_model
input_fn, model_fn_lib.ModeKeys.TRAIN)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 577, in _get_features_and_labels_from_input_fn
result = self._call_input_fn(input_fn, mode)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 663, in _call_input_fn
return input_fn(**kwargs)
File "DLCNN.py", line 215, in <lambda>
model.train(input_fn = lambda: my_input_fn(args[0]), steps=num_steps)
File "DLCNN.py", line 149, in my_input_fn
dataset = dataset.map(parse_ARRAY)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
return MapDataset(self, map_func)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
self._create_definition_if_needed()
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 334, in _create_definition_if_needed_impl
argholder = array_ops.placeholder(argtype, name=argname)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1599, in placeholder
return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3091, in _placeholder
"Placeholder", dtype=dtype, shape=shape, name=name)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/function.py", line 703, in create_op
**kwargs)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
op_def=op_def)
File "/home/fff000/Documents/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
[[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
顺便说一下,如果sess.run(paddings.initializer)
被以下代码替换,tensorflow将不会报告错误:
node1 = tf.constant(3.0, dtype=tf.float32)
node2 = tf.constant(4.0)
node3 = tf.add(node1, node2)
print("sess.run(node3):", sess.run(node3))
我也想知道是否还有其他方法来填充从TFRecord读取的数组。感谢。
答案 0 :(得分:0)
这个答案主要关注的问题&#34;我还想知道是否还有其他方法来填充从TFRecord读取的数组。&#34;
我使用tf.shape
提取数组的形状信息,然后填充数组。
这是代码。
def my_input_fn(file_path, perform_shuffle=True, repeat_count=1):
global width, length
batchNum = 32
def parse_ARRAY(tfrecord):
features = tf.parse_single_example(
tfrecord,
# Defaults are not specified since both keys are required.
features={
'label': tf.FixedLenFeature([], tf.int64),
'length': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'ARRAY': tf.FixedLenFeature([], tf.string)
})
ARRAY = tf.decode_raw(features['ARRAY'], tf.int64)
label = tf.cast(features['label'], tf.int64)
ARRAYLength = tf.cast(features['length'], tf.int64)
ARRAYWidth = tf.cast(features['width'], tf.int64)
ARRAYshape = tf.stack([ARRAYLength, ARRAYWidth])
ARRAY = tf.reshape(ARRAY, ARRAYshape)
height = tf.shape(ARRAY)[0]
imgWidth = tf.shape(ARRAY)[1]
paddings = [[0, length - height],[0, width - imgWidth]]
#ARRAY = array_ops.pad(ARRAY, paddings)
ARRAY = tf.pad(ARRAY, paddings, "CONSTANT")
return {"ARRAY":ARRAY}, label
......
实现了进程数组从TFRecord读取的目标。但是,仍然不知道InvalidArgumentError的原因。