我想构建一个tensorflow包装器来训练模型。我们的想法是你可以在一个函数中定义你的模型,将它传递给object / wrapper,然后它将完成其余的工作。因此,您不必每次都从头开始对所有内容进行编码。我会用一些伪代码清楚地说明一下
def model():
//Define your tf graph/structure here
return output
然后你会有一个课程,你可以将你的模型,训练数据,有效数据传递给它
class tf_wrapper():
def __init___(model,training_data,valid_data):
//init stuffs
def train():
//code to train the model
在许多教程中,列车代码看起来应该是标准的:
for i in range(epochs):
sess.run(feed_dict{placeholder_X: batch_X, placeholder_Y: batchY)
我现在正在努力的是,有不同类型的模型结构,损失函数,输入管道......例如:分类任务的损失函数不同于回归(crossmax entropy vs MSE)也是精度的计算,或者您输入CNN数据的方式与RNN不同。解决这个问题的最佳方法是什么?