我正在训练一个 TensorFlow 模型,并计划之后用它进行预测。
import numpy as np
import pandas as pd
import sys
import tensorflow as tf
from tensorflow.contrib import learn
from sklearn.metrics import mean_squared_error, mean_absolute_error
from lstm_predictor import load_csvdata, lstm_model
import pymysql as mariadb
# Directory passed to regressor.save() after training.
LOG_DIR = './ops_logs'
# History multiplier used for the LSTM; also passed to load_csvdata().
K = 1
# Number of time steps handed to lstm_model() and to the RNN layer config.
TIMESTEPS = 65 * K
RNN_LAYERS = [{'steps': TIMESTEPS}]
# Passed to lstm_model() as the dense-layer configuration.
DENSE_LAYERS = [10, 10]
# Passed to the estimator as `steps`.
TRAINING_STEPS = 1000
# Passed to the estimator as `batch_size`.
BATCH_SIZE = 1
# Floor division keeps this an integer step count on Python 3 as well
# (plain `/` would yield the float 100.0 there).
PRINT_STEPS = TRAINING_STEPS // 10
def train_model(symbol=1,categ='M1',limit=1000,upgrade=False):
    """Train the LSTM regressor and print a quick check of its predictions.

    The estimator is fitted with ``logdir='model/<symbol><categ>'``, then
    evaluated on the last 10 training samples, and finally saved to
    ``LOG_DIR``.

    Args:
        symbol: Identifier used (via ``str(symbol)``) to build the model
            directory name.
        categ: String appended to ``symbol`` in the model directory name.
        limit: Not used by the visible body.
        upgrade: Not used by the visible body.

    Returns:
        None. Results are printed to stdout.
    """
    model_dir = 'model/' + str(symbol) + categ
    regressor = learn.TensorFlowEstimator(
        model_fn=lstm_model(TIMESTEPS, RNN_LAYERS, DENSE_LAYERS),
        n_classes=0,
        verbose=1,
        steps=TRAINING_STEPS,
        optimizer='Adagrad',
        learning_rate=0.03,
        continue_training=True,
        batch_size=BATCH_SIZE,
    )

    # NOTE(review): `df` is not defined in this function or anywhere in the
    # visible file; presumably it should be loaded using symbol/categ/limit.
    # Confirm where it comes from, otherwise this line raises NameError.
    X, y = load_csvdata(df, K)
    regressor.fit(X['train'], y['train'], logdir=model_dir)

    # The "test" split is just the tail of the training data, so the error
    # reported below is an in-sample figure, not a held-out one.
    X['test'] = X['train'][-10:]
    y['test'] = y['train'][-10:]
    predicted = regressor.predict(X['test'])

    print('actual', 'predictions')
    for i, yi in enumerate(y['test']):
        print(yi[0], ' ', predicted[i])

    # This is the mean *absolute* error, so name it accordingly.
    mae = mean_absolute_error(y['test'], predicted)
    print("mean_absolute_error : %f" % mae)

    # NOTE(review): the model is saved to LOG_DIR while training checkpoints
    # go to model_dir; confirm which directory prediction code should read.
    regressor.save(LOG_DIR)
# Runs training with the default arguments (symbol=1, categ='M1') whenever
# this module is executed.
train_model()
然后我想编写一个预测函数,它从 model/** 目录读取已保存的模型并进行预测。
def predict(symbol=1,categ='M1'):
    """Placeholder for loading a trained model and running predictions.

    Not implemented yet: the body does nothing and the function returns
    None.

    Args:
        symbol: Identifier of the model to load (currently unused).
        categ: Category string of the model to load (currently unused).
    """
    # TODO: how to load saved model data ?
    # NOTE(review): one option may be to rebuild the estimator with the same
    # model_fn and point it at the checkpoint directory; confirm against the
    # tf.contrib.learn version in use.
    pass
但是我无法使用 regressor = learn.TensorFlowEstimator.restore( LOG_DIR ) 来加载模型,
因为该方法目前尚未实现。
请问我今后应该如何多次重复进行预测?
模型检查点保存为:
checkpoint model.ckpt-8001.meta
events.out.tfevents.1476102309.hera.creatory.org model.ckpt-8301-00000-of-00001
events.out.tfevents.1476102926.hera.creatory.org model.ckpt-8301.meta
events.out.tfevents.1476105626.hera.creatory.org model.ckpt-8601-00000-of-00001
events.out.tfevents.1476106521.hera.creatory.org model.ckpt-8601.meta
events.out.tfevents.1476106839.hera.creatory.org model.ckpt-8901-00000-of-00001
events.out.tfevents.1476107001.hera.creatory.org model.ckpt-8901.meta
events.out.tfevents.1476107462.hera.creatory.org model.ckpt-9000-00000-of-00001
graph.pbtxt model.ckpt-9000.meta
model.ckpt-8001-00000-of-00001