Universal Tensorflow Wrapper用于模型训练

时间:2018-01-20 19:25:36

标签: python tensorflow

我想构建一个tensorflow包装器来训练模型。我们的想法是你可以在一个函数中定义你的模型,将它传递给object / wrapper,然后它将完成其余的工作。因此,您不必每次都从头开始对所有内容进行编码。我会用一些伪代码清楚地说明一下

def model():
    //Define your tf graph/structure here
    return output 

然后你会有一个课程,你可以将你的模型,训练数据,有效数据传递给它

class tf_wrapper():
   def __init___(model,training_data,valid_data):
      //init stuffs
   def train():
     //code to train the model

在许多教程中,列车代码看起来应该是标准的:

for i in range(epochs):
    sess.run(feed_dict{placeholder_X: batch_X, placeholder_Y: batchY)

我现在正在努力的是,有不同类型的模型结构,损失函数,输入管道......例如:分类任务的损失函数不同于回归(crossmax entropy vs MSE)也是精度的计算,或者您输入CNN数据的方式与RNN不同。解决这个问题的最佳方法是什么?

0 个答案:

没有答案