将此代码从tensorflow 1移植到tensorflow 2

时间:2020-01-22 16:35:19

标签: tensorflow tensorflow2.0

我正在尝试在stackoverflow.com上移植在其中一个答案中找到的代码:

import tensorflow as tf

x = tf.placeholder(tf.float32,shape=[3,3])

cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10]]

print(sess.run(y, feed_dict={x: sample}))

到目前为止,我已经完成了:

import tensorflow as tf

x = tf.keras.Input(shape=[3,3], dtype=tf.dtypes.float32)

cond1 = tf.where(x > 10, x - 10, tf.zeros_like(x))
cond2 = tf.where(x < 4, x + 60, tf.zeros_like(x))
cond3 = tf.where(tf.logical_and(x >= 4, x <= 10), x, tf.zeros_like(x))
y = cond1 + cond2 + cond3

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10]]

但是我找不到打印结果的方法,因为按照移植指南的建议,我无法执行print(f(sample)。

1 个答案:

答案 0 :(得分:0)

首先使用创建模型,

> dfout
  x1   x2
1  1    a
2  2 <NA>
3  3    a

然后做

from tensorflow.keras.models import Model
model = Model(x, y)

res = model.predict(sample) print(res) 将是一个numpy数组。