How to use tf.case efficiently?

Time: 2018-10-25 09:16:33

Tags: tensorflow

I want to use tf.case to select different weights in my network, but it is far too slow! For example:

import tensorflow as tf
from datetime import datetime

tf.reset_default_graph()

M_list = []
time0 = datetime.now()
LENGTH = 100

for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[10, 10], initializer=tf.constant_initializer(i)))

with tf.Session() as sess:
    Ma = tf.get_variable('Ma', shape=[10, 1000], initializer=tf.constant_initializer(1))
    choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
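    # One (predicate, branch) pair per candidate weight; i=i binds the current loop
    # value. tf.case returns the output of the branch whose predicate is true.
    case_set = [(tf.equal(choose_mat[i], 1), lambda i=i: tf.matmul(M_list[i], Ma)) for i in range(LENGTH)]
    Mo = tf.case(case_set)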
    sess.run(tf.global_variables_initializer())

    time1 = datetime.now()
    create_time = time1 - time0
    print('create time: ', str(create_time.seconds) + '.' + str(create_time.microseconds).zfill(6))

    for i in range(LENGTH):
        CM = [0] * LENGTH
        CM[i] = 1
        mo = sess.run(Mo, feed_dict={choose_mat: CM})

    time2 = datetime.now()
    run_time = time2 - time1
    total_time = time2 - time0
    print('run time: ', str(run_time.seconds) + '.' + str(run_time.microseconds).zfill(6))
    print('total time: ', str(total_time.seconds) + '.' + str(total_time.microseconds).zfill(6))

Result:

create time:  23.969327
run time:  12.362408
total time:  36.331735
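The timing code above formats each datetime.timedelta by hand (whole seconds, a dot, then microseconds zero-padded to six digits). As a small aside, timedelta.total_seconds() should give the same figure in one call; a minimal sketch, where format_manual and format_total are illustrative names, not part of the original code:

```python
from datetime import timedelta

def format_manual(td):
    # Reproduces the question's formatting; like the original, it ignores td.days.
    return str(td.seconds) + '.' + str(td.microseconds).zfill(6)

def format_total(td):
    # total_seconds() returns a float that also accounts for days.
    return f"{td.total_seconds():.6f}"

td = timedelta(seconds=23, microseconds=969327)
print(format_manual(td))  # prints 23.969327
print(format_total(td))   # prints 23.969327
```

For benchmarks like these, time.perf_counter() is also a common alternative to datetime.now().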

As far as I know, tf.case evaluates every branch in case_set. So I select the weight before computing the matmul instead, like this:

import tensorflow as tf
from datetime import datetime

tf.reset_default_graph()

M_list = []
time0 = datetime.now()
LENGTH = 100

for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[10, 10], initializer=tf.constant_initializer(i)))


with tf.Session() as sess:
    choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
    # Select the weight matrix first, so the matmul below is built only once.
    case_set = [(tf.equal(choose_mat[i], 1), lambda i=i: M_list[i]) for i in range(LENGTH)]
    M = tf.case(case_set)

    Ma = tf.get_variable('Ma', shape=[10, 1000], initializer=tf.constant_initializer(1))
    Mo = tf.matmul(M, Ma)
    sess.run(tf.global_variables_initializer())

    time1 = datetime.now()
    create_time = time1 - time0
    print('create time2: ', str(create_time.seconds) + '.' + str(create_time.microseconds).zfill(6))

    for i in range(LENGTH):
        CM = [0] * LENGTH
        CM[i] = 1
        mo = sess.run(Mo, feed_dict={choose_mat: CM})

    time2 = datetime.now()
    run_time = time2 - time1
    total_time = time2 - time0
    print('run time: ', str(run_time.seconds) + '.' + str(run_time.microseconds).zfill(6))
    print('total time: ', str(total_time.seconds) + '.' + str(total_time.microseconds).zfill(6))

Result:

create time2:  23.321199
run time:  5.747378
total time:  29.068577

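This is faster (run time drops from about 12.4 s to 5.7 s), so it seems to help, but it is still slow. If we don't use tf.case at all, like this: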

import tensorflow as tf
from datetime import datetime

tf.reset_default_graph()

M_list = []
time0 = datetime.now()
LENGTH = 100

for i in range(LENGTH):
    M_list.append(tf.get_variable('M'+str(i), shape=[10, 10], initializer=tf.constant_initializer(i)))

with tf.Session() as sess:
    Ma = tf.get_variable('Ma', shape=[10, 1000], initializer=tf.constant_initializer(1))
    choose_mat = tf.placeholder(tf.int32, shape=[LENGTH])
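    # One matmul per weight; the branch is chosen in Python by indexing Mo_list,
    # so choose_mat is fed below but never used by the fetched ops.
    Mo_list = [tf.matmul(M_list[i], Ma) for i in range(LENGTH)]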

    sess.run(tf.global_variables_initializer())

    time1 = datetime.now()
    create_time = time1 - time0
    print('create time: ', str(create_time.seconds) + '.' + str(create_time.microseconds).zfill(6))

    for i in range(LENGTH):
        CM = [0] * LENGTH
        CM[i] = 1
        mo = sess.run(Mo_list[i], feed_dict={choose_mat: CM})

    time2 = datetime.now()
    run_time = time2 - time1
    total_time = time2 - time0
    print('run time: ', str(run_time.seconds) + '.' + str(run_time.microseconds).zfill(6))
    print('total time: ', str(total_time.seconds) + '.' + str(total_time.microseconds).zfill(6))

Result:

create time:  0.547081
run time:  0.596932
total time:  1.144013

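So with tf.case the run time is roughly 10 to 20 times that of the version without tf.case! But I need to select the weights with a tensor. How can I use tf.case efficiently, or is there a more efficient approach?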

Thanks a lot.
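If the one-hot interface of choose_mat should be kept, one control-flow-free option is to stack the candidate weights into a single [LENGTH, 10, 10] array and contract the one-hot vector against its first axis, which yields exactly the selected matrix. This is a NumPy analogue, not TensorFlow code, and it assumes choose_mat contains exactly one 1; the names W and select_by_one_hot are illustrative:

```python
import numpy as np

LENGTH = 100
# Stack the candidate weights: W[i] is filled with the constant i,
# like M_list[i] in the question. Shape: [LENGTH, 10, 10].
W = np.stack([np.full((10, 10), float(i)) for i in range(LENGTH)], axis=0)
Ma = np.ones((10, 1000))

def select_by_one_hot(choose_mat, W):
    # Computes sum_i choose_mat[i] * W[i]; with exactly one 1 at position k,
    # this is W[k]. No branching is involved.
    return np.tensordot(choose_mat, W, axes=1)

k = 42
CM = np.zeros(LENGTH)
CM[k] = 1
M = select_by_one_hot(CM, W)   # shape (10, 10)
Mo = M @ Ma                    # shape (10, 1000)
```

In TensorFlow the counterpart would presumably be tf.tensordot(tf.cast(choose_mat, tf.float32), W, axes=1) with W built by tf.stack. The contraction costs LENGTH × 100 multiply-adds, small next to the 10 × 10 × 1000 matmul, but this has not been benchmarked. If a plain index is preferred, np.argmax(CM) (tf.argmax in TensorFlow) recovers k.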

1 answer:

Answer 0 (score: 0)

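I found an efficient way: don't use tf.case! It is just too clumsy.

This approach selects the weights with a tensor efficiently: concatenate all candidate matrices into one tensor and slice out the one you need.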

import tensorflow as tf
from datetime import datetime

tf.reset_default_graph()

time0 = datetime.now()
LENGTH = 100
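# Give each candidate a trailing axis of size 1 so they can be concatenated
# into a single [10, 10, LENGTH] tensor.
M_list = [tf.get_variable('M'+str(i), shape=[10, 10, 1], initializer=tf.constant_initializer(i)) for i in range(LENGTH)]
M_concat = tf.concat(M_list, axis=2, name='M_concat')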

with tf.Session() as sess:
    choose_mat = tf.placeholder(tf.int32, shape=[1])
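    # Slice out the selected matrix by index; the result has shape [10, 10].
    M = M_concat[:, :, choose_mat[0]]
    M = tf.squeeze(M)  # already [10, 10], so this squeeze is effectively a no-op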

    Ma = tf.get_variable('Ma', shape=[10, 1000], initializer=tf.constant_initializer(1))
    Mo = tf.matmul(M, Ma)

    sess.run(tf.global_variables_initializer())

    time1 = datetime.now()
    create_time = time1 - time0
    print('create time: ', str(create_time.seconds) + '.' + str(create_time.microseconds).zfill(6))

    for i in range(LENGTH):
        mo = sess.run(Mo, feed_dict={choose_mat: [i]})

    time2 = datetime.now()
    run_time = time2 - time1
    total_time = time2 - time0
    print('run time: ', str(run_time.seconds) + '.' + str(run_time.microseconds).zfill(6))
    print('total time: ', str(total_time.seconds) + '.' + str(total_time.microseconds).zfill(6))

Result:

create time:  0.540483
run time:  0.085812
total time:  0.626295
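The answer's trick is to keep all candidate matrices in one tensor and pick one by index, so the graph holds a single slice and a single matmul instead of 100 conditional branches. The following is a NumPy analogue (not TensorFlow, not benchmarked) that checks the answer's last-axis layout against an alternative stacked first-axis layout; select_last_axis, select_first_axis and W are illustrative names:

```python
import numpy as np

LENGTH = 100
# Mirrors the answer's setup: each candidate has shape [10, 10, 1] and is
# filled with its index i; concatenating on axis 2 gives [10, 10, LENGTH].
M_list = [np.full((10, 10, 1), float(i)) for i in range(LENGTH)]
M_concat = np.concatenate(M_list, axis=2)
Ma = np.ones((10, 1000))

def select_last_axis(M_concat, idx):
    # Same indexing as M_concat[:, :, choose_mat[0]]: a scalar index drops
    # the last axis, so the result is already [10, 10] without a squeeze.
    return M_concat[:, :, idx]

# Alternative layout: stack the [10, 10] matrices along a new axis 0,
# so W[idx] selects one whole matrix.
W = np.stack([m[:, :, 0] for m in M_list], axis=0)  # shape [LENGTH, 10, 10]

def select_first_axis(W, idx):
    return W[idx]

# Both layouts pick the same matrix, so the matmul results are identical.
for i in (0, 7, LENGTH - 1):
    assert np.array_equal(select_last_axis(M_concat, i) @ Ma,
                          select_first_axis(W, i) @ Ma)
```

In TensorFlow, the stacked layout would presumably be built with tf.stack on [10, 10] variables and indexed with tf.gather(W, idx); whether it is faster than the answer's slice was not measured here.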

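I'm not sure what tf.case is meant for, given how slow it is here. We do still need if/switch-style control flow, though, and I think it ought to be more efficient.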