Make tf.Estimator use the default graph

Asked: 2017-09-11 18:19:14

Tags: tensorflow

I am trying to feed TensorFlow from a protobuf (TFRecord) pipeline. The simplest route seems to be to use tf.estimator.Estimator together with tf.contrib.data.TFRecordDataset. However, I ran into the problem that the Estimator creates a new graph, even though I launch it inside with g.as_default(). In the code below, the model tensors and the tensors returned by TFRecordDataset are on the same graph before I hand them to the Estimator, but they end up on different graphs inside the Estimator. Any ideas on how to get them onto the same graph?

# coding: utf-8
import sys
import tensorflow as tf
from keras.applications.inception_v3 import InceptionV3
import numpy as np

final_activation='linear'
g = tf.Graph()
with g.as_default():
    model = InceptionV3(weights='imagenet',
                        include_top=True,
                        input_tensor=None,
                        input_shape=None,
                        pooling=None,
                        classes=1000)

    def model_fn(mode, features, labels, params):
        optimizer = params["optimizer"]
        opt_params= params.get("opt_params", {})
        predictions = model(features)

        if (mode == tf.estimator.ModeKeys.TRAIN or
            mode == tf.estimator.ModeKeys.EVAL):
            # EstimatorSpec expects a scalar loss, so average over the batch.
            # The Keras backend takes the target first, then the output.
            loss = tf.reduce_mean(
                tf.contrib.keras.backend.categorical_crossentropy(labels, predictions))
            #loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logyhat)
        else:
            loss = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = getattr(tf.train, optimizer)
            # opt_params is a dict of keyword arguments; the Estimator also
            # expects the train op to advance the global step.
            train_op = optimizer(**opt_params).minimize(
                loss, global_step=tf.train.get_global_step())
        else:
            train_op = None

        return tf.estimator.EstimatorSpec(
              mode=mode,
              predictions=predictions,
              loss=loss,
              train_op=train_op)

    def parser(record):
        keys_to_features = {
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
            }
        features = tf.parse_single_example(
          record,
          features=keys_to_features)

        # Convert from a scalar string tensor to a float32 tensor
        image = tf.decode_raw(features['image_raw'], tf.float32)

        height = tf.cast(features['height'], tf.int32)
        width = tf.cast(features['width'], tf.int32)

        image_shape = tf.stack([height, width, 3])

        image = tf.reshape(image, image_shape)
        label = tf.cast(features["label"], tf.int32)
        return image, label

    def get_dataset_inp_fn(filenames, epochs=20):
        def dataset_input_fn():
            dataset = tf.contrib.data.TFRecordDataset(filenames)
            # Use `Dataset.map()` to build a pair of a feature dictionary and a label
            # tensor for each example.
            dataset = dataset.map(parser)
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(32)
            dataset = dataset.repeat(epochs)
            iterator = dataset.make_one_shot_iterator()

            features, labels = iterator.get_next()
            return features, labels
        return dataset_input_fn


    inpfun = get_dataset_inp_fn(["mydataset.tfrecords"], epochs=20)
    x,y = inpfun()
    print("X", x.graph)
    print("DEFAULT", g)
    print("MODEL", model.input.graph)
    # everything is on the same graph

    if x.graph is not tf.get_default_graph():
        raise ValueError()
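    # At this point x, y and the Keras model really are all on g; the
    # mismatch only shows up later, inside Estimator.train().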

with tf.Session(graph=g) as sess:
    est = tf.estimator.Estimator(
            model_fn,
            model_dir=None,
            config=None,
            params={"optimizer": "AdamOptimizer",
                    "opt_params":{}}
            )
    est.train(inpfun)
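
My best guess is that Estimator.train() constructs a brand-new tf.Graph internally and calls input_fn and model_fn inside that graph, so nothing I build under g beforehand is visible to it. Below is an untested sketch of the workaround I am considering, where InceptionV3 is only instantiated inside the model function so that it lands in whatever graph the Estimator builds. model_fn_v2 is just my name for it, the sparse cross-entropy is my guess at matching the integer labels, and I have not checked whether the pretrained ImageNet weights survive the Estimator's own variable initialization.

def model_fn_v2(features, labels, mode, params):
    # Build the Keras model here so that its tensors and variables are created
    # in the graph the Estimator itself constructs, not in the outer graph g.
    model = InceptionV3(weights='imagenet', include_top=True, classes=1000)
    predictions = model(features)

    loss = None
    train_op = None
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        # Labels in the TFRecords are integer class ids, so use the sparse
        # cross-entropy on the softmax output and reduce it to a scalar.
        loss = tf.reduce_mean(
            tf.contrib.keras.backend.sparse_categorical_crossentropy(
                labels, predictions))
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_cls = getattr(tf.train, params["optimizer"])
        train_op = optimizer_cls(**params.get("opt_params", {})).minimize(
            loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op)

est = tf.estimator.Estimator(
        model_fn_v2,
        params={"optimizer": "AdamOptimizer", "opt_params": {}})
est.train(get_dataset_inp_fn(["mydataset.tfrecords"], epochs=20))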

0 answers:

No answers