在tensorflow中添加新的op - Shape函数

时间:2017-03-20 20:28:33

标签: python c++ function tensorflow shape

我正在尝试在Tensorflow中添加一个新操作,其中我有两个输入,即3D张量和常数,它输出一个4D张量。通过将由常数定义的次数复制三维张量来获得4D张量。 形状函数以下列方式实现:

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    ::tensorflow::shape_inference::ShapeHandle output;
    ::tensorflow::shape_inference::ShapeHandle out1 = c->Vector(::tensorflow::shape_inference::DimensionOrConstant(5));
    TF_RETURN_IF_ERROR(c->Concatenate(c->input(0),out1,&output));
    c->set_output(0,output);
    return Status::OK();
})
.Doc(R"doc(
     Replicating the 3D input tensor in a 4D tensor.
)doc");

我希望第四维的大小(由代码中的out1定义)设置为第二个输入(即常量值)。怎么做?

1 个答案:

答案 0 :(得分:0)

也许MakeShapeFromShapeTensor正是您要找的?类似的东西:

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c)
{
    ::tensorflow::shape_inference::ShapeHandle n;
    TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &n));
    ::tensorflow::shape_inference::ShapeHandle out;
    TF_RETURN_IF_ERROR(c->Concatenate(n, c->input(0), &out));
    c->set_output(0, out);
    return Status::OK();
})

那就是说,你可能知道这一点,但只是为了确定:Element-wise arithmetic operations in TensorFlow support broadcasting,所以至少在这种情况下你不需要这个自定义操作。

对于其他情况,您还可以合并tf.tiletf.shapetf.concattf.reshape以达到同样的效果。例如,以下内容通过重复向量来创建矩阵:

import tensorflow as tf
oneD = tf.constant([1,2])
n = tf.constant([5])
twoD = tf.reshape(tf.tile(oneD, n), tf.concat([n, tf.shape(oneD)], 0))

with tf.Session() as sess:
  print oneD.eval()
  print twoD.eval()