Usage of tf.device

Time: 2019-06-13 01:59:34

Tags: python tensorflow deep-learning

I am defining a custom GRUCell. Below is a representative example.

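The first code block defines the GRUCell. Its weight parameters are created under tf.device('/cpu:0'):

import tensorflow as tf  # TensorFlow 1.x API

RNNCell = tf.nn.rnn_cell.RNNCell


class GRUCell(RNNCell):
    def __init__(self, input_size, hidden_size, activation=tf.tanh,
                 init_device='/cpu:0', dtype=tf.float32, reuse=tf.AUTO_REUSE):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation
        self.dtype = dtype

        # With reuse=tf.AUTO_REUSE, variables in this scope are created once and then shared.
        with tf.variable_scope('grucell', reuse=reuse):
            # Pin variable creation to init_device ('/cpu:0' by default).
            with tf.device(init_device):
                w = tf.get_variable('w', [self.input_size, self.hidden_size],
                                    dtype=self.dtype)
                b = ~~  # (elided in the question)

The second code block contains some operations that use the GRUCell defined above.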

What happens if I run build_model above?

What if I don't set the GRUCell's device to '/cpu:0'?

1 Answer:

Answer 0 (score: 0)

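Each op is placed on the device specified by the nearest (innermost) enclosing tf.device scope upstream of it. So the GRU ends up on the GPU and everything else on the CPU.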

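import tensorflow as tf  # TensorFlow 1.x graph-mode API

with tf.device('cpu'):        # outer scope: CPU
  x = tf.zeros([1])
  y = tf.zeros([1])
  print(x, y)
  sm = x + y                  # 'add' op: nearest scope is 'cpu'
  with tf.device('gpu'):      # inner scope overrides the outer one
    z = tf.zeros([1])
    ml = sm * z               # 'mul' op: nearest scope is 'gpu'

# log_device_placement=True logs the device chosen for every op.
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
  sess.run(ml)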

Output:

add: (Add): /job:localhost/replica:0/task:0/device:CPU:0
2019-06-13 07:14:36.132810: I tensorflow/core/common_runtime/placer.cc:1059] add: (Add)/job:localhost/replica:0/task:0/device:CPU:0
mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
2019-06-13 07:14:36.132821: I tensorflow/core/common_runtime/placer.cc:1059] mul: (Mul)/job:localhost/replica:0/task:0/device:GPU:0
zeros: (Const): /job:localhost/replica:0/task:0/device:CPU:0
2019-06-13 07:14:36.132827: I tensorflow/core/common_runtime/placer.cc:1059] zeros: (Const)/job:localhost/replica:0/task:0/device:CPU:0
zeros_1: (Const): /job:localhost/replica:0/task:0/device:CPU:0
2019-06-13 07:14:36.132833: I tensorflow/core/common_runtime/placer.cc:1059] zeros_1: (Const)/job:localhost/replica:0/task:0/device:CPU:0
zeros_2: (Const): /job:localhost/replica:0/task:0/device:GPU:0
2019-06-13 07:14:36.132838: I tensorflow/core/common_runtime/placer.cc:1059] zeros_2: (Const)/job:localhost/replica:0/task:0/device:GPU:0
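Each op appears twice in this log: once as a plain `name: (Type): device` line and once with a timestamped `placer.cc` prefix. A small helper can turn the plain lines into an op-to-device table for quick inspection. This is a sketch that assumes exactly the line format shown here; `parse_placements` is a hypothetical name, not a TensorFlow API.

```python
import re

# One plain placement line per op, e.g.
#   "add: (Add): /job:localhost/replica:0/task:0/device:CPU:0"
# The timestamped "I tensorflow/..." copies do not match this pattern.
PLACEMENT_RE = re.compile(r"^(\S+): \((\w+)\): (\S+)$")


def parse_placements(log_text):
    """Map op name -> (op type, short device name) for each plain placement line."""
    placements = {}
    for line in log_text.splitlines():
        match = PLACEMENT_RE.match(line.strip())
        if match:
            name, op_type, full_device = match.groups()
            # Keep only the part after '/device:', e.g. 'CPU:0' or 'GPU:0'.
            placements[name] = (op_type, full_device.rsplit("/device:", 1)[-1])
    return placements


sample = """\
add: (Add): /job:localhost/replica:0/task:0/device:CPU:0
2019-06-13 07:14:36.132810: I tensorflow/core/common_runtime/placer.cc:1059] add: (Add)/job:localhost/replica:0/task:0/device:CPU:0
mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
zeros_2: (Const): /job:localhost/replica:0/task:0/device:GPU:0
"""

for op_name, (op_type, dev) in parse_placements(sample).items():
    print(f"{op_name:8} {op_type:6} {dev}")
```

Fed the full log, it would list add, zeros and zeros_1 on CPU:0 and mul and zeros_2 on GPU:0, which is the split the answer describes.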