TensorFlow implementation of CRF loss

Date: 2018-06-29 15:28:57

Tags: python tensorflow crf

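I am trying to use a conditional random field (CRF) loss in a TensorFlow graph.

I am working on a sequence labeling task:

My input is a sequence of elements, [A, B, C, D]. Each element belongs to one of 3 classes. The classes are one-hot encoded: an element of class 0 is represented by the vector [1, 0, 0].

My input labels (y) have shape (batch_size x sequence_length x num_classes).

My network produces logits of the same shape.

Assume that all my sequences have length 4.

Here is my code: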
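The following NumPy sketch illustrates the encoding just described and the conversion from one-hot vectors back to class indices that the code in this question performs with tf.argmax. The labels [0, 2, 1, 0] are hypothetical, chosen only for illustration.

```python
import numpy as np

num_classes = 3
# Hypothetical labels for one sequence [A, B, C, D]: classes 0, 2, 1, 0
labels = [0, 2, 1, 0]
# Rows of the identity matrix are one-hot vectors; add a batch axis
# so the shape is (batch_size, sequence_length, num_classes) = (1, 4, 3)
input_y = np.eye(num_classes, dtype=np.int32)[labels][np.newaxis]
print(input_y[0, 0])  # [1 0 0]  (class 0)

# argmax over the last axis recovers each element's class index,
# the NumPy counterpart of tf.argmax(input_y, -1)
dense_y = input_y.argmax(axis=-1)
print(dense_y.shape, dense_y)  # (1, 4) [[0 2 1 0]]
```

The CRF functions take these integer tag indices, of shape (batch_size, sequence_length), rather than the one-hot labels.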

import tensorflow as tf

sequence_length = 4
num_classes = 3
input_y = tf.placeholder(tf.int32, shape=[None, sequence_length, num_classes])
logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)

log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)

I get the following error:

  File "", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 182, in crf_log_likelihood
    transition_params)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 109, in crf_sequence_score
    false_fn=_multi_seq_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/utils.py", line 206, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/smart_cond.py", line 59, in smart_cond
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2063, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1913, in BuildCondBranch
    original_result = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 95, in _single_seq_fn
    array_ops.concat([example_inds, tag_indices], axis=1))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2975, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1734, in __init__
    control_input_ops)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1570, in _create_c_op
    raise ValueError(str(e))
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,5] and params shape: [?,3] for 'cond/GatherNd' (op: 'GatherNd') with input shapes: [?,3], [?,5].

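1 Answer:

Answer 0 (score: 1)

The error comes from the shape of the sequence-length argument. crf_log_likelihood expects sequence_lengths to be a vector that holds one true length per example (shape [batch_size]), not a scalar.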
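Here is a NumPy sketch of the two ways to build such a vector. The batch size of 3, the token ids, and the use of 0 as the padding id are assumptions made only for illustration.

```python
import numpy as np

batch_size, max_len = 3, 4

# Case 1 (the question's setup): every sequence has length 4,
# so the lengths vector simply repeats 4 once per example.
fixed_lengths = np.full(batch_size, max_len, dtype=np.int32)

# Case 2 (padded input): token ids padded with 0 (assumed padding id);
# counting the non-zero ids in each row mirrors tf.reduce_sum(tf.sign(input_x), 1).
input_x = np.array([[5, 2, 9, 1],
                    [7, 3, 0, 0],
                    [4, 0, 0, 0]], dtype=np.int32)
padded_lengths = np.sign(input_x).sum(axis=1)

print(fixed_lengths.shape, fixed_lengths)    # (3,) [4 4 4]
print(padded_lengths.shape, padded_lengths)  # (3,) [4 2 1]
```

For the fixed-length case in the question's graph, something like tf.fill([tf.shape(logits)[0]], sequence_length) should produce the equivalent [batch_size] vector. This is a suggestion and has not been tested against that TensorFlow version.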

import tensorflow as tf

num_classes = 3
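input_x = tf.placeholder(tf.int32, shape=[None, None], name="input_x")
# Leave the time dimension as None: sequence_length is only defined below, as a tensor
input_y = tf.placeholder(tf.int32, shape=[None, None, num_classes])
# One true length per example, shape [batch_size]; assumes token id 0 is padding
sequence_length = tf.reduce_sum(tf.sign(input_x), 1)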

# After some network operation you will come up with logits

logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
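# Convert one-hot labels to class indices, shape [batch_size, max_seq_len]
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)
log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)
# Negative mean log-likelihood as the training loss
loss = tf.reduce_mean(-log_likelihood)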
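To see why the lengths must be given per example, here is a NumPy sketch of a linear-chain CRF log-likelihood. It is not TensorFlow's code. It assumes a model of unary scores (the logits) plus a [num_tags, num_tags] transition matrix, where transitions[i, j] scores tag i followed by tag j, with no start or end transitions. That is my understanding of the tf.contrib.crf model, not something confirmed here. All helper names are hypothetical, and every length is assumed to be at least 1.

```python
import itertools
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def crf_log_likelihood_np(unary, tags, lengths, transitions):
    """Log-likelihood of each gold tag sequence under a linear-chain CRF.

    unary:       [batch, max_len, num_tags] scores (the logits)
    tags:        [batch, max_len] gold tag indices
    lengths:     [batch] true length of each sequence (assumed >= 1)
    transitions: [num_tags, num_tags], transitions[i, j] scores tag i -> tag j
    Positions at or beyond lengths[b] are padding and are ignored.
    """
    out = np.zeros(unary.shape[0])
    for b, n in enumerate(lengths):
        u, t = unary[b, :n], tags[b, :n]
        # Score of the gold path: unary terms plus transitions between neighbours.
        gold = u[np.arange(n), t].sum() + transitions[t[:-1], t[1:]].sum()
        # Forward algorithm: alpha[j] = log-sum of scores of all prefixes ending in tag j.
        alpha = u[0]
        for i in range(1, n):
            alpha = logsumexp(alpha[:, None] + transitions, axis=0) + u[i]
        out[b] = gold - logsumexp(alpha, axis=0)
    return out


def brute_force_log_z(u, transitions):
    """Log partition by enumerating every tag path (only for tiny checks)."""
    n, num_tags = u.shape
    scores = [
        sum(u[i, p[i]] for i in range(n))
        + sum(transitions[p[i], p[i + 1]] for i in range(n - 1))
        for p in itertools.product(range(num_tags), repeat=n)
    ]
    return logsumexp(np.array(scores), axis=0)


rng = np.random.default_rng(0)
unary = rng.normal(size=(2, 4, 3))             # batch of 2, max_len 4, 3 classes
tags = np.array([[0, 2, 1, 0], [1, 1, 0, 0]])
lengths = np.array([4, 2])                     # second sequence is padded after 2 steps
transitions = rng.normal(size=(3, 3))

ll = crf_log_likelihood_np(unary, tags, lengths, transitions)
loss = -ll.mean()                              # negative mean log-likelihood
print(ll, loss)
```

Each example's loop stops at lengths[b], so the padded steps of the second sequence never enter its score or its partition function. A single scalar length cannot express that distinction across the batch.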