在张量流中重新分配不变张量

时间:2019-04-16 03:12:53

标签: python tensorflow

我有一个要求,我想使用x的更新值作为RNN的输入。下面的代码段可能会详细说明您。

x = tf.placeholder("float", shape=[None,1])
RNNcell = tf.nn.rnn_cell.BasicRNNCell(....)
outputs, _ = tf.dynamic_rnn(RNNCell, tf.reshape(x, [1,-1,1]))
x = outputs[-1] * (tf.Varaibles(...) * tf.Constants(...)) 

2 个答案:

答案 0 :(得分:1)

@Vlad的答案是正确的,但由于新成员无法投票。以下代码段是带有RNN单元的Vlads one的更新版本。

$url = "https://api.weather.gov/products/fead3465-2e6f-4350-ae90-15aaa61b91ff"  
$WebRequest = [System.Net.WebRequest]::Create($url)
$WebRequest.Method = "GET"
$WebRequest.ContentType = "application/json"
$WebRequest.UseDefaultCredentials = $true
$Response = $WebRequest.GetResponse()
$ResponseStream = $Response.GetResponseStream()
$ReadStream = New-Object System.IO.StreamReader $ResponseStream
$data = $ReadStream.ReadToEnd()

[System.Reflection.Assembly]::LoadWithPartialName("System.Web.Extensions")
$ser = New-Object System.Web.Script.Serialization.JavaScriptSerializer
$json = $ser.DeserializeObject($data)
echo $json

答案 1 :(得分:0)

这个例子或多或少是不言自明的。我们获取模型的输出,将其乘以某个张量(可以是标量,或者可以广播的秩为> 0的张量),再次将其输入模型并得到结果:

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, shape=(None, 2))
w = tf.Variable(tf.random_normal([2, 2]))
bias = tf.Variable(tf.zeros((2, )))
output1 = tf.matmul(x, w) + bias

some_value = tf.constant([3, 3],      # <-- Some tensor the output will be multiplied by
                         dtype=tf.float32)
output1 *= some_value*x  # <-- The output had been multiplied by `some_value`
                         #     (in this case with broadcasting in case of
                         #     more than one input sample)

with tf.control_dependencies([output1]):   # <-- Not necessary, but explicit control
    output2 = tf.matmul(output1, w) + bias #     dependencies is always good practice.

data = np.ones((3, 2)) # 3 two-dimensional samples

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output2, feed_dict={x:data}))
    # [[3.0432963 3.6584744]
    #  [3.0432963 3.6584744]
    #  [3.0432963 3.6584744]]