TensorFlow LSTM incremental learning and multiple predictions

Date: 2016-10-10 12:46:18

Tags: python machine-learning save tensorflow

I am training a TensorFlow model that I plan to use for predictions later.

import numpy as np
import pandas as pd
import sys

import tensorflow as tf
from tensorflow.contrib import learn
from sklearn.metrics import mean_squared_error, mean_absolute_error
from lstm_predictor import load_csvdata, lstm_model

import pymysql as mariadb

LOG_DIR = './ops_logs'
K = 1 # history used for lstm.
TIMESTEPS = 65*K 
RNN_LAYERS = [{'steps': TIMESTEPS}]
DENSE_LAYERS = [10, 10]
TRAINING_STEPS = 1000
BATCH_SIZE = 1
PRINT_STEPS = TRAINING_STEPS // 10  # integer division keeps the step count an int in Python 3

def train_model(symbol=1,categ='M1',limit=1000,upgrade=False):

    MODEL_DIR = 'model/'+str(symbol)+categ

    regressor = learn.TensorFlowEstimator(model_fn=lstm_model(TIMESTEPS, RNN_LAYERS, DENSE_LAYERS),
                                          n_classes=0,
                                          verbose=1,
                                          steps=TRAINING_STEPS,
                                          optimizer='Adagrad',
                                          learning_rate=0.03,
                                          continue_training=True,
                                          batch_size=BATCH_SIZE )

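    # NOTE: df (the input pandas DataFrame) is not defined in this snippet; it must be loaded beforehand
    X, y = load_csvdata(df, K)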

    regressor.fit(X['train'], y['train'] , logdir=MODEL_DIR ) #logdir=LOG_DIR)

    # evaluate on the last 10 training samples (these were already seen during fit)
    X['test'] = X['train'][-10:]
    y['test'] = y['train'][-10:]
    predicted = regressor.predict(X['test'])
    print('actual', 'predictions')
    for i,yi in enumerate(y['test']):
        print(yi[0], '  ' ,predicted[i])

    mae = mean_absolute_error(y['test'], predicted)
    print("mean_absolute_error : %f" % mae)

    ###############################

    regressor.save( LOG_DIR )


train_model()

Next, I want to write a predict function that reads the model from model/** and makes predictions.

def predict(symbol=1,categ='M1'):
    pass
    # how to load saved model data ?  
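A predict function like this first has to find the newest checkpoint in the model directory. A minimal sketch of that step, assuming the layout in the checkpoint listing of this question: the Saver keeps a small text index file named `checkpoint` whose `model_checkpoint_path` entry names the latest checkpoint prefix. In TensorFlow 1.x itself, `tf.train.latest_checkpoint(MODEL_DIR)` performs this lookup; the standard-library version below only illustrates it, and the helper names `model_dir_for` and `latest_checkpoint_prefix` are hypothetical.

```python
import os
import re


def model_dir_for(symbol=1, categ='M1'):
    """Rebuild the directory name used by train_model ('model/' + str(symbol) + categ)."""
    return 'model/' + str(symbol) + categ


def latest_checkpoint_prefix(model_dir):
    """Return the checkpoint prefix recorded as model_checkpoint_path, or None.

    Hypothetical helper: parses the text-format 'checkpoint' index file
    that sits next to the model.ckpt-* files.
    """
    index_file = os.path.join(model_dir, 'checkpoint')
    if not os.path.exists(index_file):
        return None
    with open(index_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                path = match.group(1)
                # Relative entries are interpreted relative to the model directory.
                return path if os.path.isabs(path) else os.path.join(model_dir, path)
    return None
```

The returned prefix (for example `model/1M1/model.ckpt-9000`) is what a restore call expects, not any single file on disk.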

However, I cannot load the model with

    regressor = learn.TensorFlowEstimator.restore(LOG_DIR)

because that method is not implemented yet.

How can I run predictions repeatedly later on, without retraining each time?
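A common way to predict many times is to pay the loading cost once and keep the restored model in memory, instead of rebuilding and restoring it on every call. The sketch below shows that caching pattern under stated assumptions: `load_fn(model_dir)` is a hypothetical function that restores a model (for example with `tf.train.Saver.restore` inside a long-lived session) and returns an object with a `predict` method. `CachedPredictor` and `load_fn` are illustrative names, not part of `tf.contrib.learn`.

```python
class CachedPredictor:
    """Load each model at most once per (symbol, categ) and reuse it.

    Hypothetical wrapper: load_fn restores a model from a directory and is
    called only the first time that directory is requested.
    """

    def __init__(self, load_fn):
        self._load_fn = load_fn
        self._models = {}

    @staticmethod
    def _model_dir(symbol, categ):
        # Same naming scheme as MODEL_DIR in train_model.
        return 'model/' + str(symbol) + categ

    def predict(self, X, symbol=1, categ='M1'):
        model_dir = self._model_dir(symbol, categ)
        if model_dir not in self._models:
            self._models[model_dir] = self._load_fn(model_dir)
        return self._models[model_dir].predict(X)

    def invalidate(self, symbol=1, categ='M1'):
        """Drop a cached model, e.g. after further training wrote a newer checkpoint."""
        self._models.pop(self._model_dir(symbol, categ), None)
```

With this, only the first predict call for a given symbol and category pays the restore cost. After incremental training, calling `invalidate` makes the next call pick up the newer checkpoint.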

The model checkpoints are saved as:

checkpoint                                        model.ckpt-8001.meta
events.out.tfevents.1476102309.hera.creatory.org  model.ckpt-8301-00000-of-00001
events.out.tfevents.1476102926.hera.creatory.org  model.ckpt-8301.meta
events.out.tfevents.1476105626.hera.creatory.org  model.ckpt-8601-00000-of-00001
events.out.tfevents.1476106521.hera.creatory.org  model.ckpt-8601.meta
events.out.tfevents.1476106839.hera.creatory.org  model.ckpt-8901-00000-of-00001
events.out.tfevents.1476107001.hera.creatory.org  model.ckpt-8901.meta
events.out.tfevents.1476107462.hera.creatory.org  model.ckpt-9000-00000-of-00001
graph.pbtxt                                       model.ckpt-9000.meta
model.ckpt-8001-00000-of-00001                    
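In this listing each checkpoint is a pair of files sharing a global-step suffix: a `.meta` file holding the graph and a `-00000-of-00001` data shard holding the variable values. `checkpoint` is the index file, and the `events.out.tfevents.*` files are TensorBoard logs. The rising step numbers (8001 through 9000) suggest that `continue_training=True` resumed from earlier checkpoints rather than restarting at step 0. The hypothetical helper below groups such a listing by step and keeps only the steps that have both parts. It checks for one data shard only, which suffices for the single-shard files shown here.

```python
import re

# Matches "model.ckpt-<step>.meta" and "model.ckpt-<step>-<shard>-of-<total>".
_CKPT_RE = re.compile(r'^model\.ckpt-(\d+)(\.meta|-\d+-of-\d+)$')


def complete_checkpoints(filenames):
    """Return the sorted global steps that have both a .meta file and a data shard."""
    parts = {}
    for name in filenames:
        match = _CKPT_RE.match(name)
        if match:
            step = int(match.group(1))
            kind = 'meta' if match.group(2) == '.meta' else 'data'
            parts.setdefault(step, set()).add(kind)
    return sorted(step for step, kinds in parts.items() if kinds == {'meta', 'data'})
```

For example, `complete_checkpoints(os.listdir(MODEL_DIR))[-1]` would give the latest usable step, 9000 for the listing above.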
