我找到了以下代码,并尝试用 tf.train.Saver() 来确认训练好的模型确实可以被重新加载并使用,但似乎有某些变量没有被正确恢复(我怀疑问题与变量作用域 variable scope 有关)。



import json
import os
import pickle
import random
import time

import babel
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
from babel.dates import format_date
from faker import Faker
from keras.models import Sequential
from keras.layers import LSTM, Embedding
from sklearn.model_selection import train_test_split


# Source of random dates for the synthetic training corpus.
fake = Faker()

# Checkpoint prefix: <directory of this script>/Model/date_model.ckpt.
# Built with os.path.join so it does not depend on "\M" / "\d" being left
# alone as invalid escape sequences, and so it also works off Windows.
_modelPath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Model", "date_model.ckpt")


# Integer character ids fed to the graph. The time dimension of `inputs` is
# left open instead of being fixed to 29 so that whatever padded length the
# generated corpus ends up with can be fed.
inputs = tf.placeholder(tf.int32, (None, None), 'inputs')
outputs = tf.placeholder(tf.int32, (None, None), 'output')
targets = tf.placeholder(tf.int32, (None, None), 'targets')

# Date patterns that create_date() samples from with random.choice().
# 'd MMMM YYY' is listed twice, which doubles its sampling weight; kept as is.
# NOTE(review): in CLDR/Babel patterns an upper-case 'Y' is the ISO
# week-numbering year, which can differ from the calendar year around New
# Year; lower-case 'y' may have been intended -- confirm before changing.
FORMATS = ['short',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           ]

# change this if you want it to work with only a single language
LOCALES = babel.localedata.locale_identifiers()
# Keep English locales only. Compare the language subtag (the part before the
# first '_') instead of testing for the substring 'en', which would also let
# through identifiers that merely contain those two letters.
LOCALES = [lang for lang in LOCALES if str(lang).split('_')[0] == 'en']

def create_date():
    """Create one fake date in a human-written and a machine-readable form.

    Returns:
        tuple: ``(human, machine)``. ``human`` is the date rendered by
        ``format_date`` with a random entry of ``FORMATS`` and a random entry
        of ``LOCALES``, then upper-cased, lower-cased or left untouched at
        random. ``machine`` is the result of ``isoformat()`` on the same date.
        ``(None, None)`` is returned when formatting raises
        ``AttributeError``; it has the same arity as the success case so that
        callers unpacking two values keep working.
    """
    dt = fake.date_object()

    try:
        human = format_date(dt, format=random.choice(FORMATS), locale=random.choice(LOCALES))

        case_change = random.randint(0, 3)  # 1/2 chance of case change
        if case_change == 1:
            human = human.upper()
        elif case_change == 2:
            human = human.lower()

        machine = dt.isoformat()
    except AttributeError:
        return None, None

    return human, machine

# Build the synthetic corpus. Samples whose first element is None mark a date
# that create_date() could not format; drop them before they reach the
# character-level encoding below.
data = [sample for sample in (create_date() for _ in range(50000)) if sample[0] is not None]

x = [sample[0] for sample in data]
y = [sample[1] for sample in data]

# Character vocabularies. Ids are assigned over the *sorted* characters:
# iteration order of a set of strings changes from one interpreter run to the
# next (string hash randomisation), and ids that move between runs would no
# longer line up with embedding rows stored in a checkpoint.
u_characters = set(' '.join(x))
char2numX = dict(zip(sorted(u_characters), range(len(u_characters))))

u_characters = set(' '.join(y))
char2numY = dict(zip(sorted(u_characters), range(len(u_characters))))

# Input side: add a padding symbol and left-pad every string to the longest one.
char2numX['<PAD>'] = len(char2numX)
num2charX = dict(zip(char2numX.values(), char2numX.keys()))
max_len = max([len(date) for date in x])

x = [[char2numX['<PAD>']]*(max_len - len(date)) +[char2numX[x_] for x_ in date] for date in x]
print(''.join([num2charX[x_] for x_ in x[4]]))
x = np.array(x)

# Output side: every target sequence starts with a <GO> symbol.
char2numY['<GO>'] = len(char2numY)
num2charY = dict(zip(char2numY.values(), char2numY.keys()))

y = [[char2numY['<GO>']] + [char2numY[y_] for y_ in date] for date in y]
print(''.join([num2charY[y_] for y_ in y[4]]))
y = np.array(y)

x_seq_length = len(x[0])
# The leading <GO> is not something the model has to predict.
y_seq_length = len(y[0])- 1


def batch_data(x, y, BATCH_SIZE):
    """Yield shuffled, equally sized mini-batches of ``x`` and ``y``.

    Both arrays are reordered with one shared random permutation, so row ``i``
    of a yielded ``x`` batch still belongs to row ``i`` of the matching ``y``
    batch. Only full batches are produced; rows left over at the end that do
    not fill a batch are not yielded.

    Args:
        x: array indexable by an index array, first axis is the sample axis.
        y: array with the same number of rows as ``x``.
        BATCH_SIZE: number of rows per yielded batch.

    Yields:
        tuple: ``(x_batch, y_batch)`` slices of ``BATCH_SIZE`` rows each.
    """
    order = np.random.permutation(len(x))
    shuffled_x, shuffled_y = x[order], y[order]

    total = len(shuffled_x)
    begin = 0
    while True:
        end = begin + BATCH_SIZE
        if end > total:
            return
        yield shuffled_x[begin:end], shuffled_y[begin:end]
        begin = end

# Embedding layers
# One trainable row per character id, initialised uniformly in [-1, 1).
# The row counts come from the vocabulary sizes, so the variable shapes depend
# on len(char2numX) / len(char2numY).
# NOTE(review): EMBEDDED_SIZE, NUMBER_OF_NODES and BATCH_SIZE are not defined
# in the visible part of this file -- confirm they are set before this point.
input_embedding = tf.Variable(tf.random_uniform((len(char2numX), EMBEDDED_SIZE), -1.0, 1.0), name='enc_embedding')
output_embedding = tf.Variable(tf.random_uniform((len(char2numY), EMBEDDED_SIZE), -1.0, 1.0), name='dec_embedding')

# Turn the integer ids from the placeholders into embedding vectors.
date_input_embed = tf.nn.embedding_lookup(input_embedding, inputs)
date_output_embed = tf.nn.embedding_lookup(output_embedding, outputs)

# Encoder: run an LSTM over the embedded input; only its final state is kept.
with tf.variable_scope("encoding") as encoding_scope:
    lstm_enc = tf.contrib.rnn.BasicLSTMCell(NUMBER_OF_NODES)
    _, last_state = tf.nn.dynamic_rnn(lstm_enc, inputs=date_input_embed, dtype=tf.float32)

# Decoder: a second LSTM over the embedded decoder input, started from the
# encoder's final state.
with tf.variable_scope("decoding") as decoding_scope:
    lstm_dec = tf.contrib.rnn.BasicLSTMCell(NUMBER_OF_NODES)
    dec_outputs, _ = tf.nn.dynamic_rnn(lstm_dec, inputs=date_output_embed, initial_state=last_state)

# Project every decoder output onto one score per output character
# (no activation, i.e. raw logits).
logits = tf.contrib.layers.fully_connected(dec_outputs, num_outputs=len(char2numY), activation_fn=None) 

with tf.name_scope("optimization"):
    # Loss function
    # Every position gets weight 1. The weight tensor is built with a fixed
    # first dimension of BATCH_SIZE.
    loss = tf.contrib.seq2seq.sequence_loss(logits, targets, tf.ones([BATCH_SIZE, y_seq_length]))
    # Optimizer
    optimizer = tf.train.RMSPropOptimizer(1e-3).minimize(loss)

# Hold out 33% of the encoded samples for evaluation (fixed random_state).
X_train, X_test, Y_train, Y_test = train_test_split(
    x, y, test_size=0.33, random_state=42
)

# Add ops to save and restore all the variables: the saver is handed every
# entry of the graph's GLOBAL_VARIABLES collection as it stands at this point.
variables_to_save = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=None)
saver = tf.train.Saver(variables_to_save)

sess = tf.Session()

# Start from the saved checkpoint when one exists; otherwise initialise every
# variable so that the graph can be trained and evaluated on a first run.
if os.path.isfile(_modelPath + ".index"):
    saver.restore(sess, _modelPath) #Yes, no need to add ".index"
    print('Done Restoring Model.')
else:
    sess.run(tf.global_variables_initializer())

print('Traning Model.')
for epoch_i in range(NUMBER_OF_EPOCHS):

    for batch_i, (source_batch, target_batch) in enumerate(batch_data(X_train, Y_train, BATCH_SIZE)):
        # The decoder is fed the target without its last symbol and is asked
        # to predict the target without its first symbol (<GO>).
        _, batch_loss, batch_logits = sess.run([optimizer, loss, logits], feed_dict = {inputs: source_batch, outputs: target_batch[:, :-1], targets: target_batch[:, 1:]})

    # Loss and accuracy reported per epoch are those of the last batch only.
    accuracy = np.mean(batch_logits.argmax(axis=-1) == target_batch[:,1:])
    print('Epoch {:3} Loss: {:>6.3f} Accuracy: {:>6.4f}'.format(epoch_i, batch_loss, accuracy))

# Save the variables to disk. The target directory is created first because
# saving into a missing directory fails.
model_dir = os.path.dirname(_modelPath)
if model_dir and not os.path.isdir(model_dir):
    os.makedirs(model_dir)
save_path = saver.save(sess, _modelPath)
print("Retraining Done. Updated model saved in file: %s" % save_path + ' ' + os.path.abspath(save_path))

#Setup test batches
source_batch, target_batch =  next(batch_data(X_test, Y_test, BATCH_SIZE))

print("{} {} total length:{}".format(source_batch[0], target_batch[0], len(source_batch)))


# Greedy decoding: start every sequence with <GO>, then repeatedly run the
# graph and append the most likely next character to the decoder input.
# Kept as an integer array so it matches the int32 placeholder and can be
# used directly as dictionary keys below.
dec_input = np.zeros((len(source_batch), 1), dtype=np.int32) + char2numY['<GO>']

for i in range(y_seq_length):
    batch_logits = sess.run(logits, feed_dict = {inputs: source_batch,  outputs: dec_input})
    prediction = batch_logits[:,-1].argmax(axis=-1)
    dec_input = np.hstack([dec_input, prediction[:,None]])

# Column 0 is the <GO> symbol on both sides and always matches, so it is left
# out of the score -- the same positions the training accuracy is computed on.
print('Accuracy on test set is: {:>6.3f}'.format(np.mean(dec_input[:, 1:] == target_batch[:, 1:])))

# Show a couple of decoded examples, with padding stripped from the input.
num_preds = 2
source_chars = [[num2charX[x_index] for x_index in sent if num2charX[x_index]!="<PAD>"] for sent in source_batch[:num_preds]]
dest_chars = [[num2charY[y_index] for y_index in sent] for sent in dec_input[:num_preds, 1:]]

for date_in, date_out in zip(source_chars, dest_chars):
    print(''.join(date_in)+' => '+''.join(date_out))

# Location of the pickled character vocabularies, kept next to the checkpoint.
# NOTE(review): no definition of _picklePath is visible elsewhere in this
# file, so it is derived from the checkpoint prefix here -- adjust if another
# location is expected.
_picklePath = os.path.splitext(_modelPath)[0] + "_vocab.pkl"

# Persist the id <-> character tables. The embedding rows are indexed by these
# ids, so the tables describe how to read the saved variables.
with open(_picklePath,'wb') as f:
    pickle.dump([char2numX, num2charX, char2numY, num2charY],f)


# Read the tables back. NOTE: pickle executes code while loading; only load
# files this script wrote itself.
# NOTE(review): at this point the tables have already been used to encode the
# data, so reloading here does not change the run -- to matter for a restored
# model the load would have to happen before the data is encoded.
with open(_picklePath,'rb') as f:
    [char2numX, num2charX, char2numY, num2charY] = pickle.load(f)
