How to skip, or set the skip to 0, with the tf.data.Dataset API in TensorFlow

Asked: 2018-06-18 15:39:52

Tags: python tensorflow tensorflow-datasets

Here is the code:

import tensorflow as tf
import numpy as np
import time

index1 = tf.Variable(-1, dtype=tf.int32, trainable=False)
index2 = tf.Variable(0, dtype=tf.int32, trainable=False)

starting_point1 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point2 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point3 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point4 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point5 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)

def starting_point_add_1_0(starting_point1, n):
    assignment1 = tf.assign_add(starting_point1, n)
    assignment2 = tf.assign_add(starting_point1, 0)
    return assignment1, assignment2

mod_op1 = tf.mod(index1, 5)
mod_op1 = tf.Print(input_=mod_op1, data=[mod_op1], message="mod_op1")

mod_op2 = tf.mod(index2, 5)
mod_op2 = tf.Print(input_=mod_op2, data=[mod_op2], message="mod_op2")

condition = tf.logical_and(tf.equal(mod_op1, 0), tf.equal(mod_op2, 1))
condition = tf.Print(input_=condition, data=[condition], message="condition")

ass1_1, ass1_2 = tf.cond(condition, 
                         lambda: starting_point_add_1_0(starting_point1, 1),
                         lambda: starting_point_add_1_0(starting_point1, 0))

# ass2 = tf.cond(tf.cast(index1 % 5 == 1 or index2 % 5 == 2, dtype=tf.bool), 
#               lambda: tf.assign_add(starting_point2, 1),
#               lambda: tf.assign_add(starting_point2, 0))

# ass3 = tf.cond(index1 % 5 == 2 or index2 % 5 == 3, 
#               lambda: tf.assign_add(starting_point3, 1),
#               lambda: tf.assign_add(starting_point3, 0))

# ass4 = tf.cond(index1 % 5 == 3 or index2 % 5 == 4, 
#               lambda: tf.assign_add(starting_point4, 1),
#               lambda: tf.assign_add(starting_point4, 0))

# ass5 = tf.cond(index1 % 5 == 4 or index2 % 5 == 0, 
#               lambda: tf.assign_add(starting_point5, 1),
#               lambda: tf.assign_add(starting_point5, 0))

data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
data2 = tf.data.Dataset.range(21, 40).skip(starting_point2)
data3 = tf.data.Dataset.range(41, 60).skip(starting_point3)
data4 = tf.data.Dataset.range(61, 80).skip(starting_point4)
data5 = tf.data.Dataset.range(81, 100).skip(starting_point5)

iterator1 = data1.make_initializable_iterator()
iterator2 = data2.make_initializable_iterator()
iterator3 = data3.make_initializable_iterator()
iterator4 = data4.make_initializable_iterator()
iterator5 = data5.make_initializable_iterator()

d1 = iterator1.get_next()
d2 = iterator2.get_next()
d3 = iterator3.get_next()
d4 = iterator4.get_next()
d5 = iterator5.get_next()

data_ = tf.stack((d1, d2, d3, d4, d5), axis=0)

ass6 = tf.assign_add(index1, 1)
ass7 = tf.assign_add(index2, 1)

with tf.control_dependencies([ass6, ass7]):
    data = tf.gather_nd(data_, indices=[[index1 % 5], [index2 % 5]])

init_op = tf.global_variables_initializer()


with tf.Session() as sess:
    sess.run(init_op)
    sess.run(iterator1.initializer)
    sess.run(iterator2.initializer)
    sess.run(iterator3.initializer)
    sess.run(iterator4.initializer)
    sess.run(iterator5.initializer)

    try:
        for i in range(20):
            t1, t2, t3, s1 = sess.run([data, index1, index2, starting_point1])
            print(t1, t2, t3, ".....", s1)
            sess.run([ass1_1, ass1_2])

    except tf.errors.OutOfRangeError:
        print("error")

So, the main purpose of this code is to iterate over the 5 different datasets I created with the tf.data.Dataset.range function. I want to take 1 element from the first dataset and 1 element from the second dataset. Then I want to consider data2 and data3, then data3 and data4, then data4 and data5, then data5 and data1, and so on.
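For clarity, here is a plain-Python sketch of the pairing I am after (this is only an illustration of the indexing logic, not TensorFlow code; the five lists mirror the five Dataset.range calls above, and the per-dataset positions play the role the starting_point variables are meant to play via skip()):

```python
def round_robin_pairs(datasets, steps):
    """Take one element from dataset step%N and one from dataset (step+1)%N
    per step; each dataset's read position only advances when that dataset
    is actually read."""
    positions = [0] * len(datasets)  # per-dataset read position
    pairs = []
    for step in range(steps):
        i = step % len(datasets)
        j = (step + 1) % len(datasets)
        pairs.append([datasets[i][positions[i]], datasets[j][positions[j]]])
        positions[i] += 1
        positions[j] += 1
    return pairs

# Mirror tf.data.Dataset.range(1, 20), range(21, 40), ..., range(81, 100).
datasets = [list(range(1, 20)), list(range(21, 40)), list(range(41, 60)),
            list(range(61, 80)), list(range(81, 100))]
for pair in round_robin_pairs(datasets, 6):
    print(pair)
```

Note that with this logic the wrap-around pair at step 4 contains 2 from data1 (only one element of data1 has been consumed so far), which is the behaviour I expected from the starting_point/skip mechanism.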

This is how I have approached it, and here is the output:

[ 1 21] 0 1 ..... 0
[22 42] 1 2 ..... 1
[43 63] 2 3 ..... 1
[64 84] 3 4 ..... 1
[85  5] 4 5 ..... 1
[ 6 26] 5 6 ..... 1
[27 47] 6 7 ..... 2
[48 68] 7 8 ..... 2
[69 89] 8 9 ..... 2
[90 10] 9 10 ..... 2
[11 31] 10 11 ..... 2
[32 52] 11 12 ..... 3
[53 73] 12 13 ..... 3
[74 94] 13 14 ..... 3
[95 15] 14 15 ..... 3
[16 36] 15 16 ..... 3
[37 57] 16 17 ..... 4
[58 78] 17 18 ..... 4
[79 99] 18 19 ..... 4
error

For now I am only testing the starting point of the first dataset to see whether it works, but I got the result above. I thought I would get a 2 instead of a 5, but that is not the case. It seems the skip has almost no effect.

I would appreciate it if someone could help me solve this problem.

Finally, I am working on this because I have a custom dataset that does not fit in GPU memory, so I would like to implement something along these lines.

I tried the following test of the skip() function, with this result:

starting_point1 = tf.Variable(tf.constant(2, dtype=tf.int64), trainable=False)

data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
iterator1 = data1.make_initializable_iterator()
d1 = iterator1.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator1.initializer)

    try:
        for i in range(10):
            print(sess.run(d1))
    except tf.errors.OutOfRangeError:
        print("error")

Output:

3
4
5
6
7
8
9
10
11
12

So, how can I force the iterator to skip at every training step?
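The behaviour in the test above is consistent with the skip count being consumed once, when the iterator is created, rather than on every read. A minimal plain-Python analogue using itertools.islice (an analogy only, not the tf.data implementation; if I understand the TF 1.x iterator lifecycle correctly, building a fresh iterator here would correspond to re-running iterator1.initializer after updating starting_point1):

```python
from itertools import islice

data = range(1, 20)  # stand-in for tf.data.Dataset.range(1, 20)
skip_count = 2

# The skip count is consumed once, when the iterator is built ...
it = islice(data, skip_count, None)
print(next(it))  # 3, matching the skip(2) test above

# ... so changing the count afterwards has no effect on the live iterator.
skip_count = 5
print(next(it))  # still 4, not 6

# Only building a fresh iterator picks up the new count.
it = islice(data, skip_count, None)
print(next(it))  # 6
```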

Any help is greatly appreciated!!

0 Answers:

No answers yet.