我注意到在Tensorflow中定义了MatMul op:
Shape函数:
Status MatMulShape(shape_inference::InferenceContext* c) {
ShapeHandle a;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
ShapeHandle b;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
以及MatMulOp中的Compute函数:
void Compute(OpKernelContext* ctx) override {
const Tensor& a = ctx->input(0);
const Tensor& b = ctx->input(1);
// Check that the dimensions of the two matrices are valid.
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
errors::InvalidArgument("In[0] is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("In[1] is not a matrix"));
这意味着输入的秩(rank)必须为2,但是以下操作却没有问题:
a=tf.placeholder(tf.int32, [None, None, None])
b=tf.placeholder(tf.int32, [None, None, None])
c=tf.matmul(a, b)
这里的输入多了一个额外的批次维度(batch dim)。我想知道它是如何工作的。
我定义了一个ngram op,其输入是秩为1的张量:
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sent));
但批量应用时会发生错误:
a = tf.placeholder(tf.int32, [None, None])
c = ngram.ngram(a, vocab_size=5000, bucket_size=100000, word_ngrams=3)
为什么?
答案 0 :(得分:0)
我检查了代码,发现批处理是由另外的代码单独完成的。python/ops/math_ops.py中的matmul函数:
def matmul (a, b, ....
...
if (not a_is_sparse and not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and (b_shape is None or len(b_shape) > 2)):
...
return gen_math_ops._batch_mat_mul(a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
python/ops/gen_math_ops.py中的_batch_mat_mul函数:
def _batch_mat_mul(x, y, adj_x=False, adj_y=False, name=None):
...
if _ctx.in_graph_mode():
_, _, _op = _op_def_lib._apply_op_helper("BatchMatMul", x=x, y=y, adj_x=adj_x, adj_y=adj_y, name=name)
BatchMatMul中的Compute函数(tensorflow/core/kernels/batch_matmul_op_impl.h):
// Multiplies input 0 by input 1 as a batch of matrices and writes the result
// to output 0.
//
// Requirements checked here (each failure is reported through `ctx` as an
// InvalidArgument error and the kernel returns early):
//   * both inputs have the same number of dimensions, and at least 2;
//   * every dimension except the last two has the same size in both inputs;
//   * after the optional swaps selected by adj_x_ / adj_y_, the inner
//     dimensions of the two matrices agree.
//
// All leading dimensions are collapsed into a single batch dimension `n`, so
// the launcher always receives rank-3 tensors of shape {n, rows, cols}.
//
// A failed reshape is reported as an Internal error through `ctx` instead of
// aborting the process with CHECK, so that it is handled the same way as the
// other failures in this function.
void Compute(OpKernelContext* ctx) override {
  const Tensor& in0 = ctx->input(0);
  const Tensor& in1 = ctx->input(1);
  OP_REQUIRES(ctx, in0.dims() == in1.dims(),
              errors::InvalidArgument("In[0] and In[1] has different ndims: ",
                                      in0.shape().DebugString(), " vs. ",
                                      in1.shape().DebugString()));
  const int ndims = in0.dims();
  OP_REQUIRES(
      ctx, ndims >= 2,
      errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));

  // The leading (batch) dimensions must match exactly; they are copied into
  // the output shape as they are validated.
  TensorShape out_shape;
  for (int i = 0; i < ndims - 2; ++i) {
    OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(",
                                        i, ") must be the same: ",
                                        in0.shape().DebugString(), " vs ",
                                        in1.shape().DebugString()));
    out_shape.AddDim(in0.dim_size(i));
  }

  // Number of matrices in the batch: 1 for plain rank-2 inputs, otherwise the
  // product of the leading dimensions collected above.
  auto n = (ndims == 2) ? 1 : out_shape.num_elements();

  // View each input as a rank-3 tensor {n, rows, cols}. The reshape uses the
  // dimensions as stored, i.e. before any adjoint swap below.
  auto d0 = in0.dim_size(ndims - 2);
  auto d1 = in0.dim_size(ndims - 1);
  Tensor in0_reshaped;
  OP_REQUIRES(ctx, in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})),
              errors::Internal("Failed to reshape In[0] from ",
                               in0.shape().DebugString()));
  auto d2 = in1.dim_size(ndims - 2);
  auto d3 = in1.dim_size(ndims - 1);
  Tensor in1_reshaped;
  OP_REQUIRES(ctx, in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})),
              errors::Internal("Failed to reshape In[1] from ",
                               in1.shape().DebugString()));

  // From here on d0 x d1 and d2 x d3 describe the matrices as they take part
  // in the product, so swap rows/cols for an adjointed operand.
  if (adj_x_) std::swap(d0, d1);
  if (adj_y_) std::swap(d2, d3);
  OP_REQUIRES(ctx, d1 == d2,
              errors::InvalidArgument(
                  "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                  in0.shape().DebugString(), " ", in1.shape().DebugString(),
                  " ", adj_x_, " ", adj_y_));
  out_shape.AddDim(d0);
  out_shape.AddDim(d3);
  Tensor* out = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

  // Empty output: nothing to compute.
  if (out->NumElements() == 0) {
    return;
  }
  // Non-empty output but an empty operand: skip the launcher and let
  // SetZeroFunctor fill the output instead.
  if (in0.NumElements() == 0 || in1.NumElements() == 0) {
    functor::SetZeroFunctor<Device, Scalar> f;
    f(ctx->eigen_device<Device>(), out->flat<Scalar>());
    return;
  }

  // Give the launcher a rank-3 view {n, d0, d3} of the allocated output.
  Tensor out_reshaped;
  OP_REQUIRES(ctx, out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})),
              errors::Internal("Failed to reshape output from ",
                               out->shape().DebugString()));
  LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                            adj_x_, adj_y_, &out_reshaped);
}
最后,每个矩阵乘法都是通过Run函数计算的:
// Multiplies the matrix slices of in_x and in_y whose index lies in
// [start, limit) and stores each product in the slice of out with the same
// index. adj_x and adj_y are passed through unchanged to Multiply.
static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                bool adj_y, Tensor* out, int start, int limit) {
  for (int slice = start; slice < limit; ++slice) {
    auto lhs = ConstTensorSliceToEigenMatrix(in_x, slice);
    auto rhs = ConstTensorSliceToEigenMatrix(in_y, slice);
    auto product = TensorSliceToEigenMatrix(out, slice);

    // TODO(rmlarsen): Get rid of the special casing here when we have
    // upstreamed improvements for matrix*vector and vector*matrix to
    // Eigen's general matrix product.

    // Left operand is a single row (not adjointed).
    if (!adj_x && lhs.rows() == 1) {
      Multiply(adj_x, adj_y, lhs.row(0), rhs, product);
      continue;
    }
    // Left operand is a single column (adjointed).
    if (adj_x && lhs.cols() == 1) {
      Multiply(adj_x, adj_y, lhs.col(0), rhs, product);
      continue;
    }
    // Right operand is a single column (not adjointed).
    if (!adj_y && rhs.cols() == 1) {
      Multiply(adj_x, adj_y, lhs, rhs.col(0), product);
      continue;
    }
    // Right operand is a single row (adjointed).
    if (adj_y && rhs.rows() == 1) {
      Multiply(adj_x, adj_y, lhs, rhs.row(0), product);
      continue;
    }
    // General matrix * matrix case.
    Multiply(adj_x, adj_y, lhs, rhs, product);
  }
}