Tensorflow:为tf.map_fn构造params张量

时间:2016-10-18 07:04:29

标签: tensorflow

import tensorflow as tf
import numpy as np


def lineeqn(slope, intercept, y, x):
    return np.sign(y-(slope*x) - intercept)

# data size
DS = 100000

N = 100
x1 = tf.random_uniform([DS], -1, 0, dtype=tf.float32, seed=0)
x2 = tf.random_uniform([DS], 0, 1, dtype=tf.float32, seed=0)

# line representing the target function
rand1 = np.random.randint(0, DS)
rand2 = np.random.randint(0, DS)
T_x1 = x1[rand1]
T_x2 = x1[rand2]
T_y1 = x2[rand1]
T_y2 = x2[rand2]

slope = (T_y2 - T_y1)/(T_x2 - T_x1)
intercept = T_y2 - (slope * T_x2)

# extracting training samples from the data set
training_indices = np.random.randint(0, DS, N)
training_x1 = tf.gather(x1, training_indices)
training_x2 = tf.gather(x2, training_indices)

training_x1_ex = tf.expand_dims(training_x1, 1)
training_x2_ex = tf.expand_dims(training_x2, 1)


slope_tensor = tf.fill([N], slope)
slope_ex = tf.expand_dims(slope_tensor, 1)

intercept_tensor = tf.fill([N], intercept)
intercept_ex = tf.expand_dims(intercept_tensor, 1)

params = tf.concat(1, [slope_ex, intercept_ex, training_x2_ex, training_x1_ex])
training_y = tf.map_fn(lineeqn, params)

lineeqn函数需要4个参数,因此params应该是张量,其中每个元素是4元素张量。当我尝试运行上面的代码时,我收到错误TypeError: lineeqn() takes exactly 4 arguments (1 given)。有人可以解释我构造params张量的方式有什么问题吗? tf.map_fn对params张量有什么作用?

1 个答案:

答案 0 :(得分:2)

有人提出了类似的问题here。您收到此错误的原因是因为map_fn - lineeqn在您的情况下调用的函数 - 需要采用一个张量参数。

而不是函数的参数列表,参数elems应该是 items 的列表,其中映射函数被调用列表中包含的每个项目。 因此,为了给你的函数带来多个参数,你必须自己从每个项目中解压缩它们,例如

def lineeqn(item):
    slope, intercept, y, x = tf.unstack(item, num=4)
    return np.sign(y - (slope * x) - intercept)

并将其命名为

training_y = tf.map_fn(lineeqn, list_of_parameter_tensors)

在这里,您为list_of_parameter_tensors中的每个张量调用线方程,其中每个张量将描述打包参数的元组(slope, intercept, y, x)。 (请注意,根据实际参数张量的形状,也可能不是tf.concat而是使用tf.pack。)