Tensorflow:TypeError:int()参数必须是字符串,类字节对象或数字,而不是' Tensor'

时间:2018-05-15 20:01:47

标签: python tensorflow

我得到一个张量作为参数,现在我想创建一些变量,我正在尝试tf.get_variables(我不想使用tf.Variable)

input=tensor_argument


sequence_length = tf.shape(inputs)[1]  # the length 

hidden_size = tf.shape(inputs)[2] # hidden size 

W_omega = tf.get_variable(name='w_omega',shape=[sequence_length,sequence_length],dtype=tf.float32,initializer=tf.random_uniform_initializer(-0.01,0.01))

当我运行此代码时,我收到此错误

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Tensor'

我尝试使用tf.shape和tf.size()但没有任何工作,

示例:

import tensorflow as tf
import numpy as np


data=np.random.randint(0,10,[2,4,300])

tensor_va=tf.constant(data)

d=tf.shape(tensor_va)[1]
W_omega = tf.get_variable(name='a_omega', shape=[d,d], dtype=tf.float32,
                          initializer=tf.random_uniform_initializer(-0.01, 0.01))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(W_omega))

如何使用或将张量转换为int for get_variable?

1 个答案:

答案 0 :(得分:0)

在TF 2.3中提供解决方案,默认情况下,启用急切执行。要解决上述问题,您应该将dtensor转换为numpy

如下所示完成工作代码

%tensorflow_version 2.x
print(tf.__version__)

import tensorflow as tf
import numpy as np


data=np.random.randint(0,10,[2,4,300])

tensor_va=tf.constant(data)

d=tensor_va.numpy().shape [1]

W_omega = tf.compat.v1.get_variable(name='a_omega', shape=[d,d], dtype=tf.float32,
                          initializer=tf.random_uniform_initializer(-0.01, 0.01))

print(W_omega)

输出:

2.3.0
<tf.Variable 'a_omega:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0.00958031, -0.00010247,  0.00161975, -0.00098458],
       [ 0.00287828, -0.00026206, -0.00369554,  0.00339144],
       [-0.00205746, -0.00392851,  0.00062867,  0.00173409],
       [-0.00409856,  0.0012137 , -0.00762768, -0.00545195]],
      dtype=float32)>