How to properly reset the graph and variables after restoring a trained model

Date: 2018-02-03 02:12:02

Tags: python tensorflow scope restore

I found the following code and tried to implement Saver() to confirm that a trained model can actually be loaded and reused, but it seems that some variables (or something else, my suspicion is the variable-scope items) are not being restored correctly.

The code checks whether a saved model exists. If it does not, it trains and saves a model; if it does, it tries to restore the model and then run predictions against the test data. On the first run, which trains the model, the test-data predictions work fine, but on subsequent runs that load the saved model, the test-data predictions fail (they produce garbage dates/formats).

Any assistance is greatly appreciated.

import numpy as np
import matplotlib.pyplot as plt

import random
import json
import os
import time

from faker import Faker
import babel
from babel.dates import format_date

import tensorflow as tf

from keras.models import Sequential
from keras.layers import LSTM, Embedding

import tensorflow.contrib.legacy_seq2seq as seq2seq

from sklearn.model_selection import train_test_split

NUMBER_OF_EPOCHS = 10
BATCH_SIZE = 128
NUMBER_OF_NODES = 32
EMBEDDED_SIZE = 10

fake = Faker()
fake.seed(42)
random.seed(42)

# use os.path.join so the checkpoint path is built portably (no raw backslash escapes)
_modelPath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Model", "date_model.ckpt")

tf.reset_default_graph()

inputs = tf.placeholder(tf.int32, (None, 29), 'inputs')
outputs = tf.placeholder(tf.int32, (None, None), 'output')
targets = tf.placeholder(tf.int32, (None, None), 'targets')

FORMATS = ['short',
           'medium',
           'long',
           'full',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY',
           ]

# change this if you want it to work with only a single language
LOCALES = babel.localedata.locale_identifiers()
LOCALES = [lang for lang in LOCALES if 'en' in str(lang)]

def create_date():
    """
        Creates some fake dates 
        :returns: tuple containing
                  1. human formatted string
                  2. machine formatted string
    """
    dt = fake.date_object()

    try:
        human = format_date(dt, format=random.choice(FORMATS), locale=random.choice(LOCALES))

        case_change = random.randint(0,3) # 1/2 chance of case change
        if case_change == 1:
            human = human.upper()
        elif case_change == 2:
            human = human.lower()

        machine = dt.isoformat()
    except AttributeError as e:
        return None, None

    return human, machine #, dt

data = [create_date() for _ in range(50000)]

x = [x for x, y in data]
y = [y for x, y in data]

u_characters = set(' '.join(x))
char2numX = dict(zip(u_characters, range(len(u_characters))))

u_characters = set(' '.join(y))
char2numY = dict(zip(u_characters, range(len(u_characters))))

char2numX['<PAD>'] = len(char2numX)
num2charX = dict(zip(char2numX.values(), char2numX.keys()))
max_len = max([len(date) for date in x])

x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]
print(''.join([num2charX[x_] for x_ in x[4]]))
x = np.array(x)

char2numY['<GO>'] = len(char2numY)
num2charY = dict(zip(char2numY.values(), char2numY.keys()))

y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in y]
print(''.join([num2charY[y_] for y_ in y[4]]))
y = np.array(y)

x_seq_length = len(x[0])
y_seq_length = len(y[0])- 1

print(x_seq_length)

def batch_data(x, y, BATCH_SIZE):
    shuffle = np.random.permutation(len(x))
    start = 0
#     from IPython.core.debugger import Tracer; Tracer()()
    x = x[shuffle]
    y = y[shuffle]
    while start + BATCH_SIZE <= len(x):
        yield x[start:start+BATCH_SIZE], y[start:start+BATCH_SIZE]
        start += BATCH_SIZE

# Embedding layers
input_embedding = tf.Variable(tf.random_uniform((len(char2numX), EMBEDDED_SIZE), -1.0, 1.0), name='enc_embedding')
output_embedding = tf.Variable(tf.random_uniform((len(char2numY), EMBEDDED_SIZE), -1.0, 1.0), name='dec_embedding')

date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)
date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)

with tf.variable_scope("encoding") as encoding_scope:
    lstm_enc = tf.contrib.rnn.BasicLSTMCell(NUMBER_OF_NODES)
    _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)

with tf.variable_scope("decoding") as decoding_scope:
    lstm_dec = tf.contrib.rnn.BasicLSTMCell(NUMBER_OF_NODES)
    dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed, initial_state=last_state)

# connect decoder outputs to a dense layer producing one logit per output character
logits = tf.contrib.layers.fully_connected(dec_outputs, num_outputs=len(char2numY), activation_fn=None) 

with tf.name_scope("optimization"):
    # Loss function
    loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([BATCH_SIZE, y_seq_length]))
    # Optimizer
    optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)

X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=0.33, random_state=42)

saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)) # Add ops to save and restore all the variables.

sess = tf.Session() 
sess.run(tf.global_variables_initializer())

if os.path.isfile(_modelPath + ".index"):

    saver.restore(sess, _modelPath) #Yes, no need to add ".index"
    print('Done Restoring Model.')


else:

    print('Training Model.')
    for epoch_i in range(NUMBER_OF_EPOCHS):

        for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, Y_train, BATCH_SIZE)):
            _, batch_loss, batch_logits = sess.run([optimizer, loss, logits], feed_dict = {inputs: source_batch, outputs: target_batch[:, :-1], targets: target_batch[:, 1:]})

        accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])
        print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f}'.format(epoch_i, batch_loss, accuracy))

    # Save the variables to disk.
    save_path = saver.save(sess, _modelPath)
    print("Retraining Done. Updated model saved in file: %s" % save_path + ' ' + os.path.abspath(save_path))

#Setup test batches
source_batch, target_batch =  next(batch_data(X_test, Y_test, BATCH_SIZE))

print("{} {} total length:{}".format(source_batch[0], target_batch[0], len(source_batch)))

print(char2numY['<GO>'])

dec_input = np.zeros((len(source_batch), 1)) + char2numY['<GO>']

for i in range(y_seq_length):
    batch_logits = sess.run(logits, feed_dict = {inputs: source_batch,  outputs: dec_input})
    prediction = batch_logits[:,-1].argmax(axis=-1)
    dec_input = np.hstack([dec_input, prediction[:,None]])

print('Accuracy on test set is: {:>6.3f}'.format(np.mean(dec_input == target_batch)))

num_preds = 2
source_chars = [[num2charX[x_index] for x_index in sent if num2charX[x_index]!="<PAD>"] for sent in source_batch[:num_preds]]
dest_chars = [[num2charY[y_index] for y_index in sent] for sent in dec_input[:num_preds, 1:]]

for date_in, date_out in zip(source_chars, dest_chars):
    print(''.join(date_in)+' => '+''.join(date_out))

1 Answer:

Answer 0 (score: 0)

Found the problem. My lookup dictionaries (the num/char mappings) were being built in a different random order on every run, so the lookups used when restoring the model did not match the ones used when it was trained. I need to pickle these lookups when the model is trained: they are a fixed set of values that has to stay associatedated with the trained model.
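(The randomness comes from building the dictionaries out of unordered sets, whose iteration order changes from one Python run to the next because of hash randomization. As an aside, a minimal sketch of an alternative fix, assuming the generated data itself is reproducible thanks to the fake.seed(42)/random.seed(42) calls, is to sort the characters before building the lookups so they come out identical on every run:)

# Alternative sketch: sorting gives the characters a stable order, so the same
# index is assigned to the same character on every run and nothing needs pickling.
u_characters = sorted(set(' '.join(x)))
char2numX = {ch: i for i, ch in enumerate(u_characters)}

u_characters = sorted(set(' '.join(y)))
char2numY = {ch: i for i, ch in enumerate(u_characters)}

char2numX['<PAD>'] = len(char2numX)
num2charX = {i: ch for ch, i in char2numX.items()}

char2numY['<GO>'] = len(char2numY)
num2charY = {i: ch for ch, i in char2numY.items()}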

So, when I originally train the model, I need to do this to save the lookups:

import pickle

with open(_picklePath, 'wb') as f:
    pickle.dump([char2numX, num2charX, char2numY, num2charY], f)

Then, when running predictions again, I need to reload those lookups:

with open(_picklePath,'rb') as f:
    [char2numX, num2charX, char2numY, num2charY] = pickle.load(f)
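For completeness, here is a rough sketch of how those two snippets could be wired into the existing script; _picklePath is not defined anywhere above, so it is simply assumed here to live next to _modelPath. Note that the lookups have to be loaded before the graph is built, because the placeholder/embedding/logits shapes depend on len(char2numX) and len(char2numY):

import pickle

# Assumed location for the pickled lookups, next to the checkpoint files.
_picklePath = os.path.join(os.path.dirname(_modelPath), "date_model_lookups.pkl")

# This block must run BEFORE the placeholders and embedding variables are
# created, since their shapes depend on len(char2numX) and len(char2numY).
if os.path.isfile(_picklePath):
    # Prediction/restore run: reuse the exact lookups the model was trained with.
    with open(_picklePath, 'rb') as f:
        [char2numX, num2charX, char2numY, num2charY] = pickle.load(f)
else:
    # Training run: build char2numX/num2charX/char2numY/num2charY as in the
    # original code above, then persist them alongside the model checkpoint.
    with open(_picklePath, 'wb') as f:
        pickle.dump([char2numX, num2charX, char2numY, num2charY], f)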

Of course, I had to rearrange the original code to accommodate this. The question now is: can the pickling be done with the "Saver" object instead? If so, how?