使用2D CNN进行视频数据的Tensorflow-培训有问题吗?

时间:2019-03-04 14:51:41

标签: python-3.x tensorflow deep-learning conv-neural-network

我想知道是否可以通过将网络映射到各个帧上来在视频数据上训练2D CNN:

def mobile_netv1(inputs, layer):
    with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
    layer_endpoints = tf.map_fn(lambda x: mobilenet_v1.mobilenet_v1
                    (x, num_classes=1001, is_training=True, reuse=tf.AUTO_REUSE)[1][layer], inputs)
return layer_endpoints

其中输入将为[批,帧,高度,宽度,通道],并且端点将在完全连接的层中使用以创建logit。

虽然这个小例子对分类没有意义,但我想知道训练过程中是否存在问题,例如在Adam优化器中,权重的平均值无法正确计算(仅在示例的子集上)。

还是最好将前两个维度[batch *框架]合并起来,然后再将其重塑为[batch,框架]?

编辑: 一个最小的工作示例:

def mobile_netv1(inputs, layer):
  with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope()):
    layer_endpoints = tf.map_fn(lambda x: mobilenet_v1.mobilenet_v1
                (x, num_classes=1001, is_training=True, reuse=tf.AUTO_REUSE)[1][layer], inputs)
  return layer_endpoints


def model_fn(features, labels, mode):
  output = mobile_netv1(features, 'Conv2d_12_depthwise')
  mul_dims = 1
  for i in output.get_shape().as_list()[1:]:
    mul_dims *= i
  flat_output = tf.reshape(output, (-1, mul_dims))
  logits = tf.layers.dense(flat_output, 10)
  one_hot_labels = tf.one_hot(labels, depth=10)
  with tf.variable_scope('loss'):
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot_labels, logits=logits)
    optimizer = tf.train.AdamOptimizer()
    m_loss = tf.reduce_mean(loss)
    train_op = optimizer.minimize(m_loss, tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=m_loss, train_op=train_op)


def input_fn():
  dataset1 = tf.data.Dataset.from_tensor_slices({"features": tf.ones([10, 10, 128, 128, 3]),
                                                "labels": tf.ones([10, 1], dtype=tf.int64)})
  dataset1 = dataset1.repeat()
  dataset1 = dataset1.batch(batch_size)
  iterator = dataset1.make_one_shot_iterator()
  next_element = iterator.get_next()
  return next_element['features'], next_element['labels']

if __name__ == '__main__':
  classifier = tf.estimator.Estimator(model_fn, model_dir='./run')
  classifier.train(input_fn, steps=2)

0 个答案:

没有答案