创建矩阵 - 操作张量

时间:2017-06-15 14:51:34

标签: python numpy tensorflow

我正在尝试在TensorFlow中实现一种非线性过滤器,但是我在执行一步时遇到了问题。该步骤基本上类似于:

x_update = x.assign(tf.matmul(A, x))

问题是矩阵A的结构类似于:

A = [[1, 0.1, 0, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, f1(x), f2(x), f3(x)],
     [0, 0, f4(x), f5(x), f6(x)],
     [0, 0, 0, 0, 1]]

每个fn(x)是我国家的非线性函数;类似于tf.sin(x[4])甚至x[2]**2 * tf.sin(x[4]) + x[3]**2 * tf.cos(x[4])

我不知道如何创建我的A矩阵,以便嵌入这些操作。我首先用一些值初始化它:

A_mat = np.eye(5)
A_mat[0, 1] = 0.1
A = tf.Variable(A_mat, dtype=tf.float32, trainable=False, name='A')

然后我尝试使用tf.scatter_update进行一些切片更新,例如:

# Define my nonlinear operations.
f1 = tf.cos(...)
f2 = tf.sin(...)
# ...

# Define the part that I want to substitute.
new_part = tf.constant(tf.convert_to_tensor([[f1, f2, f3],
                                             [f4, f5, f6]]))

# Define slice indices and update the matrix.
inds = [vals for vals in zip(np.arange(1, 3), np.arange(2, 5))]
A_update = tf.scatter_update(A, tf.constant(inds), new_part, name='A_update')

这给我一个错误说明:

  

ValueError:形状必须等于等级,但是为1和0

     

将形状1与其他形状合并。对于带有输入形状的'packed / 0'(op:'Pack'):[1],[1],[],[],[],[]。

我还尝试将我的矩阵new_part分配回numpy定义的A_mat,但我得到了一个不同的错误,我认为这是由于数字数组突然出现时的意外数据类型分配了张量元素。

那么有人知道如何定义在使用矩阵时更新的操作矩阵吗?

理想情况下,我想定义矩阵A,以便在A内更新的所有操作都是A调用的一部分,并自动发生。这样我可以完全避免切片分配,并且只会感觉更多TensorFlow-y。

谢谢!

更新

我通过将tf.reshape(op_name, [])中的操作包装起来并将更新更改为:

来解决错误
new_part = tf.convert_to_tensor([[0, 0, f1, f2, f3],
                                 [0, 0, f4, f5, f6]]))
rows = np.arange(start_row, end_row)
A_update = tf.scatter_update(A, rows, new_part, name='A_update')

事实证明tf.scatter_update只能在变量的第一个维度上运行,因此我必须向其提供完整的行并将行索引放在我想要的位置。这有帮助,但仍然存在我的问题:

我的问题:

定义此A矩阵的最佳,最TensorFlow-y方式是什么,以便那些常量元素保持不变,并且我图上其他张量运算的元素嵌入{{1 }} 因此?我希望在我的图表上调用A来完成并运行这些更新,而无需手动执行此操作A。或者这是正确的方法吗?

1 个答案:

答案 0 :(得分:2)

更新子矩阵的最简单方法是使用tensorflow的python切片操作。

Class<? extends MyConcreteClass> dynamicType = new ByteBuddy()
    .subclass(MyConcreteClass.class, ConstructorStrategy.Default.DEFAULT_CONSTRUCTOR)
    .name(dynamicClassName)
    .annotateType(MyConcreteClass.class.getDeclaredAnnotations())
    .make()
    .load(MyConcreteClass.class.getClassLoader())
    .getLoaded();