I am trying to create an input pipeline with Dataset.map() that maps a .h5 parser function through a py_func wrapper. I want to pass two arguments to the map function: filename and window_size. The code below has the call order from_tensor_slices -> Dataset.map -> _pyfn_wrapper -> parse_h5. The drawback is that _pyfn_wrapper can only accept a single argument when used with map(), because from_tensor_slices cannot zip two kinds of data: a string and then an int.

The following snippet is runnable after first creating some random data.
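A minimal data-generation sketch (hypothetical; it only assumes what parse_h5 below reads: flat 'X' and 'y' float32 arrays of window_size**2 values per file):

import h5py
import numpy as np

window_size = 100
for i in range(20):
    # each file holds an input/label pair that parse_h5 can reshape
    # to (window_size, window_size, 1)
    with h5py.File('./{}.h5'.format(i), 'w') as f:
        f.create_dataset('X', data=np.random.rand(window_size ** 2).astype(np.float32))
        f.create_dataset('y', data=np.random.rand(window_size ** 2).astype(np.float32))

The pipeline itself: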
import os
import multiprocessing as mp

import h5py
import tensorflow as tf

def helper(window_size, batch_size, ncores=mp.cpu_count()):
    flist = []
    for dirpath, _, fnames in os.walk('./'):
        for fname in fnames:
            flist.append(os.path.abspath(os.path.join(dirpath, fname)))
    f_len = len(flist)

    # init list of files
    batch = tf.data.Dataset.from_tensor_slices(tf.constant(flist))  # fixme: how to zip one list of strings and one list of ints
    batch = batch.map(_pyfn_wrapper, num_parallel_calls=ncores)  # fixme: how to map two args
    batch = batch.shuffle(batch_size).batch(batch_size, drop_remainder=True).prefetch(ncores + 6)

    # construct iterator
    it = batch.make_initializable_iterator()
    iter_init_op = it.initializer

    # get next img and label
    X_it, y_it = it.get_next()
    inputs = {'img': X_it, 'label': y_it, 'iterator_init_op': iter_init_op}
    return inputs, f_len
def _pyfn_wrapper(filename):  # fixme: can only take one arg
    # filename, window_size = args  # fixme: tried to separate the args
    window_size = 100
    return tf.py_func(parse_h5,  # wrapped pythonic function
                      [filename, window_size],
                      [tf.float32, tf.float32])  # output dtypes
def parse_h5(name, window_size):
    with h5py.File(name.decode('utf-8'), 'r') as f:
        X = f['X'][:].reshape(window_size, window_size, 1)
        y = f['y'][:].reshape(window_size, window_size, 1)
        return X, y
# create tf.data.Dataset
inputs, f_len = helper(100, 5)

# inject into model
with tf.name_scope("Conv1"):
    W = tf.get_variable("W", shape=[3, 3, 1, 1],
                        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable("b", shape=[1], initializer=tf.contrib.layers.xavier_initializer())
    layer1 = tf.nn.conv2d(inputs['img'], W, strides=[1, 1, 1, 1], padding='SAME') + b
    logits = tf.nn.relu(layer1)

loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=inputs['label'], predictions=logits))
train_op = tf.train.AdamOptimizer(0.0001).minimize(loss)

# session
with tf.Session() as sess:
    sess.run(inputs['iterator_init_op'])
    sess.run(tf.global_variables_initializer())
    for step in range(f_len):
        sess.run([train_op])
Answer 0 (score: 0)
Using a nested structure of Datasets, as @Sharky suggested in a comment, is one of the solutions. To avoid the error below, the nested args should be unpacked in the last function, parse_h5, rather than in _pyfn_wrapper:
TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn.
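This is exactly what the commented-out line in the question's _pyfn_wrapper would raise: in graph mode the paired element arrives as a single string Tensor of shape (2,), and tuple unpacking tries to iterate over it:

def _pyfn_wrapper(args):
    filename, window_size = args  # TypeError: iterating a Tensor in graph mode
    ...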
The argument also has to be decoded, because args passed through tf.py_func() arrive as binary literals (bytes).
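A standalone sketch of that conversion (illustrative, not part of the answer's pipeline): a string fed through tf.py_func reaches the Python function as bytes and must be decoded before use:

import tensorflow as tf

def show(x):
    # x arrives as numpy bytes, e.g. b'100', even though a str was fed in
    return float(x.decode('utf-8'))

t = tf.py_func(show, [tf.constant('100')], tf.float64)
with tf.Session() as sess:
    print(sess.run(t))  # 100.0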
The modified code:
def helper(...):
    ...
    flist.append((os.path.abspath(os.path.join(dirpath, fname)), str(window_size)))
    ...

def _pyfn_wrapper(args):
    return tf.py_func(parse_h5,  # wrapped pythonic function
                      [args],
                      [tf.float32, tf.float32])  # output dtypes

def parse_h5(args):
    name, window_size = args  # only unpack the args here, inside the py_func
    window_size = int(window_size.decode('utf-8'))  # decode: py_func turned the string into bytes
    with h5py.File(name, 'r') as f:
        ...
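For completeness, the nested structure @Sharky mentioned can also be built from two separate Datasets with Dataset.zip, which lets map() hand the filename and the window size to the wrapper as two distinct tensors, so the int never needs to be encoded as a string. A sketch under that assumption (the paths and the two-argument wrapper are illustrative, not from the answer; parse_h5 is the question's original two-argument version):

import tensorflow as tf

filenames = ['./0.h5', './1.h5']  # illustrative paths
window_sizes = [100, 100]         # one int per file

names = tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
sizes = tf.data.Dataset.from_tensor_slices(tf.constant(window_sizes))

def _pyfn_wrapper2(name, window_size):  # map() unpacks the zipped pair into two args
    return tf.py_func(parse_h5, [name, window_size], [tf.float32, tf.float32])

batch = tf.data.Dataset.zip((names, sizes)).map(_pyfn_wrapper2)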