如何在TfLearn或TensorFlow中创建自定义指标?

时间:2017-09-07 21:50:52

标签: python numpy tensorflow tflearn metric

TensorFlow提供了许多常见的评估指标,但我不知道如何创建自己的指标。 我正在建立一个基于AlexNet的CNN模型用于抓取检测,我想在评估数据时使用矩形度量(如本文所述:https://arxiv.org/pdf/1412.3128.pdf)。矩形度量意味着满足以下两个条件:

- The grasp angle is within 30 degree of the ground truth grasp.
- The Jaccard index of the predicted grasp and the ground truth is greater than 25 percent.

所以我的第一次尝试是使用在TFLearn(https://github.com/tflearn/tflearn/blob/master/examples/images/alexnet.py)上可用的AlexNet模型,并创建一个用numpy计算度量的文件。以下是包含不完整代码的度量标准文件(因为我不允许共享),但主要部分如下:

def grasp_error(grasps,targets,max_angle = 30,min_overlap=0.25):
        return np.mean([np.max([grasp_classification(grasps[i],targets,max_angle,min_overlap) for i in range(grasps.shape[0])])]) #for target in targets[i]])

#compute the error of the test set
def grasp_classification(grasp,target,max_angle = 30,min_overlap = 0.25):
    ...
    if abs(np.arctan2(grasp[sinpos],grasp[cospos]) - np.arctan2(target[sinpos],target[cospos]))< (max_angle * 2./180.)*np.pi:
        if jaccard_index(grasp,target) > min_overlap:
            return 1
    return 0

# computes Jaccard index of two grasping rectangeles
def jaccard_index(grasp,target):
    ...
    return intersect/overall

我尝试将其添加到Tflearn文件夹中的metrics.py文件中:

class Rectangle(Metric):
def __init__(self, name="Rechtangle"):
    super(Rectangle, self).__init__(name)
    self.tensor = None

def build(self, predictions, targets, inputs=None):
    with tf.name_scope('Rechtangle'): # <--------- name scope
         with tf.Session() as sess:
              #tf.InteractiveSession()
              prediction = predictions.eval()
              target = targets.eval()
              self.tensor = tf.convert_to_tensor(grasp_error(prediction, target, max_angle = 30,min_overlap=0.25))
    self.built = True
    self.tensor.m_name = self.name
    return self.tensor

然后在AlexNet的末尾使用它:

rect_metric = tflearn.metrics.Rectangle()
network = regression(network, metric=rect_metric, optimizer='momentum',
                     loss='mean_square',
                     learning_rate=0.0005)

我收到了这个错误:

 File "gnet.py", line 57, in <module>
    learning_rate=0.0005)   
  File "/usr/local/lib/python2.7/dist-packages/tflearn/layers/estimator.py", line 159, in regression
    metric.build(incoming, placeholder, inputs)
  File "/usr/local/lib/python2.7/dist-packages/tflearn/metrics.py", line 119, in build
    prediction = predictions.eval()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 569, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3741, in _eval_using_default_session
    return session.run(tensors, feed_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 982, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1032, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1052, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value is_training
     [[Node: is_training/read = Identity[T=DT_BOOL, _class=["loc:@is_training"], _device="/job:localhost/replica:0/task:0/cpu:0"](is_training)]]

在metrics文件中实现sess.eval()是个问题,但它是将张量变为numpy数组的唯一方法,不是吗?如果您有任何想法可以解决这个问题,请告诉我。非常感谢你!

编辑:我尝试了另一种方式,按照此处的建议:https://github.com/tflearn/tflearn/issues/207并在代码中实现:

def rect_metric(prediction, target, inputs):
    x = []
    sess = tf.InteractiveSession()
    with sess as default:
        pred = prediction.eval(session=sess)
        tar =  target.eval(session=sess)
        x = tf.reduce_sum(grasp_error(pred,tar))
    return x

现在错误没有出现,但训练因此异常而停止:

Reminder: Custom metric function arguments must be defined as: custom_metric(y_pred, y_true, x).

0 个答案:

没有答案