运行交叉验证时,Tensorflow保存的模型会变大

时间:2016-03-29 08:29:01

标签: tensorflow

在Tensorflow上运行交叉验证的正确方法是什么? 以下是我的代码片段:

$('.qtip-show').each(function() { // Grab all elements with a title attribute,and set "this"
    $(this).qtip({ // 
        position: {
        my: 'top left',  // Position my top left...
        at: 'bottom left', // at the bottom right of...
        target: $(this) // my target
    }
    });
});

折叠0的已保存型号大小约为2M。但是对于4M左右的折叠1,在6M左右折叠2,依此类推。

1 个答案:

答案 0 :(得分:3)

我的猜测是TextCNN构造函数和train()方法正在将节点添加到默认图形(tf.get_default_graph()),并且保存的模型包含所有以前的图形,因此它是“意外的二次曲线“并且随着__main__循环的每次迭代而增长。

幸运的是,解决方案很简单。只需按如下方式重写主循环:

if __name__ == "__main__":
  for i in range(CV_SIZE):
    with tf.Graph().as_default():  # Performs training in a new, empty graph.
      cnn = TextCNN(i)
      cnn.train()

这将为循环的每次迭代创建一个新的空图。因此,保存的模型将不包含上一次迭代中的节点(和变量),并且模型大小应保持不变。

请注意,如果可能,您应该尝试为所有迭代重用相同的图形。但是,我意识到如果图形的结构从一次迭代变为下一次迭代,这可能是不可能的。