Batching complex data with a TensorFlow Dataset

Time: 2018-03-24 18:55:49

Tags: tensorflow iterator dataset batching

I am trying to follow the examples at this link:

https://www.tensorflow.org/programmers_guide/datasets

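But I have no idea how to run the session. I understand that the first argument is the operation(s) to run, and that feed_dict supplies values for the placeholders (which, as I understand it, are batches of the training or test dataset).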

So, here is my code:

batch_size = 100
handle_mix = tf.placeholder(tf.float64, shape=[])
handle_src0 = tf.placeholder(tf.float64, shape=[])
handle_src1 = tf.placeholder(tf.float64, shape=[])
handle_src2 = tf.placeholder(tf.float64, shape=[])
handle_src3 = tf.placeholder(tf.float64, shape=[])

I create the dataset from the mp4 tracks and stems, read the mixture and source sizes, and pad them so they fit into batches:

# one entry per component; each source uses its own padded array
dataset = tf.data.Dataset.from_tensor_slices(
    {"x_mixed": padded_lbl, "y_src0": padded_src[0], "y_src1": padded_src[1],
     "y_src2": padded_src[2], "y_src3": padded_src[3]})
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
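The plain-Python analogue below shows why each batch comes back as a dict. It is not TensorFlow. It mimics from_tensor_slices over a dict followed by .batch(n), and it leaves out shuffle/repeat so the output is deterministic. The helper names slice_dict, batch_dicts and batch_tuples are hypothetical.

The sketch also contrasts a tuple structure. As far as I know, in TF 1.x, passing a tuple such as (padded_lbl, padded_src[0], ...) to from_tensor_slices makes sess.run(next_element) return a tuple, and a tuple unpacks positionally.

```python
# Plain-Python analogue (NOT TensorFlow): mimics how a dataset built with
# from_tensor_slices(...) followed by .batch(n) groups elements, so the
# structure of each batch can be inspected without a TF session.
from typing import Dict, Iterator, List, Sequence, Tuple


def slice_dict(columns: Dict[str, Sequence[float]]) -> List[Dict[str, float]]:
    """Split a dict of equal-length columns into one dict per example."""
    n = len(next(iter(columns.values())))
    return [{key: col[i] for key, col in columns.items()} for i in range(n)]


def batch_dicts(examples: List[Dict[str, float]],
                batch_size: int) -> Iterator[Dict[str, List[float]]]:
    """Group per-example dicts into batches; each batch is again a dict."""
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        yield {key: [ex[key] for ex in chunk] for key in chunk[0]}


def batch_tuples(columns: Tuple[Sequence[float], ...],
                 batch_size: int) -> Iterator[Tuple[List[float], ...]]:
    """Same grouping, but with a tuple structure instead of a dict."""
    n = len(columns[0])
    for start in range(0, n, batch_size):
        yield tuple(list(col[start:start + batch_size]) for col in columns)


columns = {"x_mixed": [0.1, 0.2, 0.3, 0.4], "y_src0": [1.0, 2.0, 3.0, 4.0]}

dict_batch = next(batch_dicts(slice_dict(columns), batch_size=2))
mix, src0 = dict_batch          # unpacking a dict yields its KEYS
print(mix, src0)                # x_mixed y_src0

tuple_batch = next(batch_tuples((columns["x_mixed"], columns["y_src0"]), batch_size=2))
mix, src0 = tuple_batch         # unpacking a tuple yields the VALUES
print(mix, src0)                # [0.1, 0.2] [1.0, 2.0]
```

The first unpacking reproduces the symptom described in this question: you get the key strings instead of the data.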

Following the example, I should do:

next_element = iterator.get_next()

training_init_op = iterator.make_initializer(dataset)
for _ in range(20):
    # Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

However, I have loss, summary and optimizer ops that need the data fed in as batches, as in other examples like this:

l, _, summary = sess.run([loss_fn, optimizer, summary_op],
                         feed_dict={handle_mix: batch_mix, handle_src0: batch_src0, handle_src1: batch_src1,
                                    handle_src2: batch_src2, handle_src3: batch_src3})

So what I had in mind was:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = data.train.next_batch(batch_size)

or to first run the iterator on its own to get a batch, and then run the optimization as above, for example:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = sess.run(next_element)
l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_mix, handle_src0: batch_src0, handle_src1: batch_src1, handle_src2: batch_src2, handle_src3: batch_src3})

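With this last attempt, what comes back are the string names of the batch components I defined in tf.data.Dataset.from_tensor_slices ("x_mixed", "y_src0", etc.). Those strings then fail to be cast to the tf.float64 placeholders in the session.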

Please tell me how to build this dataset. First, the structure I pass to from_tensor_slices may be wrong. Second, how should I batch it?

Many thanks,

1 answer:

Answer 0 (score: 2)

The problem is that you pack your data into a dict when creating the dataset from tensor slices. This causes iterator.get_next() to return each batch as a dict. If we do something like

d = {"a": 1, "b": 2}
k1, k2 = d

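we get k1 == "a" and k2 == "b" (or the other way around, since dict keys are unordered). In the same way, unpacking the result of sess.run(next_element) only gives you the dict keys, whereas you are interested in the dict values (the tensors). This should work:

next_element = iterator.get_next()
x_mixed = next_element["x_mixed"]
y_src0 = next_element["y_src0"]
...

If you build your model on the variables x_mixed etc., it should work fine. Note that with the tf.data API you don't need placeholders! TensorFlow sees that your model output requires, e.g., x_mixed, which in turn requires iterator.get_next(). It therefore executes that op whenever you sess.run() your loss function, optimizer, etc. If you are more comfortable with placeholders, you can of course keep using them; just remember to unpack the dict properly. This should be correct: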
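Here is a minimal plain-Python sketch of "unpacking the dict properly" for the placeholder route. It looks each value up by key, so the result never depends on dict ordering. The placeholder handles are stand-in strings, and HANDLE_TO_KEY and build_feed_dict are hypothetical names. In real TensorFlow code, the keys of the mapping would be the tf.placeholder objects from the question.

```python
# Plain-Python sketch: turning a dict-structured batch into a feed_dict.
# The "handles" below are hypothetical stand-ins (plain strings); in TF the
# keys would be the tf.placeholder tensors (handle_mix, handle_src0, ...).
from typing import Any, Dict, Mapping

# Hypothetical mapping from each placeholder handle to the dataset key it receives.
HANDLE_TO_KEY = {
    "handle_mix": "x_mixed",
    "handle_src0": "y_src0",
    "handle_src1": "y_src1",
    "handle_src2": "y_src2",
    "handle_src3": "y_src3",
}


def build_feed_dict(batch: Mapping[str, Any],
                    handle_to_key: Mapping[Any, str]) -> Dict[Any, Any]:
    """Pick each value by key so the result does not depend on dict order."""
    missing = [key for key in handle_to_key.values() if key not in batch]
    if missing:
        raise KeyError(f"batch is missing keys: {missing}")
    return {handle: batch[key] for handle, key in handle_to_key.items()}


# A fake batch shaped like the dict returned by sess.run(next_element).
batch = {
    "x_mixed": [[0.1, 0.2]],
    "y_src0": [[1.0, 1.0]],
    "y_src1": [[2.0, 2.0]],
    "y_src2": [[3.0, 3.0]],
    "y_src3": [[4.0, 4.0]],
}
feed = build_feed_dict(batch, HANDLE_TO_KEY)
print(feed["handle_mix"])   # [[0.1, 0.2]]
```

With real placeholders, a training step would call batch = sess.run(next_element) first. It would then call sess.run([loss_fn, optimizer, summary_op], feed_dict=build_feed_dict(batch, HANDLE_TO_KEY)).

This still copies every batch out of TensorFlow and feeds it back in, which the placeholder-free approach avoids. The question's placeholders are also declared with shape=[], which only accepts scalars. Feeding a batch into them would fail, so they would need shape=None or an explicit shape such as [None, ...].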