使用TensorflowJS转换器时,tf.gather中的“ Axis”参数无效

时间:2018-10-26 17:12:08

标签: tensorflow.js tensorflowjs-converter

我正在Python中使用Tensorflow定义和训练模型,然后将其保存到save_model并使用TensorflowJS将其加载到网站上。 axis参数在gather op中根本没有任何作用,无论如何都为0,但是在Python应用程序中却如此。我在下面做了一个最小的例子。我正在使用Windows 7 x64,Python 3.5.2,numpy 1.15.1(在安装tensorflowjs时强制执行-在此之前已安装15.2)和仅使用pip install tensorflowjs安装的最新版本的tensorflowjs。该浏览器是Chrome,并且全部使用Wamp在localhost上提供。

第1步:Python

# coding=utf-8

import tensorflow as tf
import sys
import os, shutil

tf.reset_default_graph()

x = tf.placeholder(tf.int32, shape=[2, 2], name='x')

res = tf.gather(x, tf.constant([0], dtype="int32"), axis=1) # it gathers rows, not columns, despite axis=1
res = tf.identity(res, name="y")

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    print(sess.run(res, feed_dict={x: [[17,20],[4,44]]}))

    tf.saved_model.simple_save(sess, "export", inputs={"x": x}, outputs={"y": res})

第2步:tfjs-converter命令

tensorflowjs_converter --input_format=tf_saved_model --output_node_names="y" --saved_model_tags=serve export C:/wamp/www/avatar/web_model

第3步:服务网站:

<!-- index.html -->

<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml">

<head>
  <meta http-equiv="Content-Type" content="text/html" charset="utf-8" />
  <title>test</title>
  <script src="./ext/jquery-3.3.1.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
</head>

<body>
  <script src="./script.js"></script>
</body>

</html>
// script.js

async function learnLinear() {

const model = await tf.loadFrozenModel('http://localhost/avatar/web_model/tensorflowjs_model.pb', 'http://localhost/avatar/web_model/weights_manifest.json');

  x = tf.tensor2d([[17,20],[4,44]], [2,2], "int32");

  console.log(model.execute({"x": x}).print());
}
learnLinear();                                   

因此,axis都设置为0和1,gather返回Web应用程序中的第一行,而不是第一列。但是在Python Tensorflow中,它可以按预期工作,并且在运行Python脚本时会打印第一列。

当然,我总是可以转置然后执行gather,但这是该操作的关键功能,如果尚未真正包含它,则可能有兴趣实现它。 ..

感谢您的帮助!

0 个答案:

没有答案