Error when trying to extract a frozen TensorFlow graph

Time: 2018-11-01 11:34:16

Tags: python tensorflow keras

I am trying to load a tf.keras model, extract the frozen graph, and use it with tensorframes (specifically the TensorFlow image transformer from the Spark DL library).

I have done this before with a keras model, and the code worked fine. But for some reason I cannot figure out, the same code fails with this tf.keras model.


I get ValueError: Input 0 of node batch_normalization/cond/ReadVariableOp/Switch was passed float from batch_normalization/gamma:0 incompatible with expected resource.

import numpy as np
from keras.preprocessing.image import img_to_array, load_img
from keras.applications.vgg16 import preprocess_input
import matplotlib.pyplot as plt
from glob import glob
import os

from sparkdl import KerasImageFileTransformer
from pyspark import SparkContext, SparkConf
from pyspark.sql.types import StringType
from pyspark.sql import SparkSession
from tensorflow.python.keras import losses
import tensorflow as tf
def dice_coeff(y_true, y_pred):
    smooth = 1.
    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss

def bce_dice_loss(y_true, y_pred):
    loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

model = tf.keras.models.load_model(
    '/Users/vivek.vanga/Downloads/mask_model_803.h5',
    custom_objects={
        "dice_coeff": dice_coeff,
        "dice_loss": dice_loss,
        "bce_dice_loss": bce_dice_loss,
    })
sess = tf.keras.backend.get_session()
print(sess.graph)
#print(sess.run(model.output.name))
import sparkdl.graph.utils as tfx

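# Freeze the graph: replace variables with constants, keeping only the
# subgraph needed to compute the model output.
frozen = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(add_shapes=True),
        [tfx.op_name(model.output.name, sess.graph)])
# Import the frozen GraphDef into a fresh graph.
g = tf.Graph()  # pylint: disable=invalid-name
with g.as_default():
    tf.import_graph_def(frozen, name='')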
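
A note on the third argument of the freeze call: convert_variables_to_constants takes a list of node (op) names, while model.output.name is a tensor name that ends in an output index such as ":0". The script relies on sparkdl's tfx.op_name for that conversion. The snippet below is only a pure-Python sketch of the idea, not sparkdl's implementation; the function name tensor_name_to_op_name is hypothetical, and the example name is taken from the error message.

```python
def tensor_name_to_op_name(tensor_name):
    """Strip a trailing ':<output index>' from a tensor name, if present.

    Hypothetical helper for illustration; it does not touch a real graph.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if sep and index.isdigit():
        return op_name
    # No numeric output index: assume this is already an op name.
    return tensor_name


print(tensor_name_to_op_name("batch_normalization/gamma:0"))
print(tensor_name_to_op_name("batch_normalization/gamma"))
```

Both calls print batch_normalization/gamma, so the helper is safe to apply to a name that is already an op name.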

If you want to reproduce this, the model file is available at https://drive.google.com/file/d/1gYAmids4t9f0fLPlI4BfQmFE4v1eKdIq/view?usp=sharing. TensorFlow version used: tensorflow==1.10.0
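
For anyone checking the custom objects passed to load_model, here is a small worked example of the Dice formula used in dice_coeff above, written in plain Python so it runs without TensorFlow. It is a sketch for hand verification only; the function name dice_coeff_plain and the sample values are made up for illustration.

```python
def dice_coeff_plain(y_true, y_pred, smooth=1.0):
    """Dice coefficient on flat lists, mirroring the formula in dice_coeff."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * intersection + smooth) / (sum(y_true) + sum(y_pred) + smooth)


# Illustrative values: two positive pixels in the mask, one predicted correctly.
y_true = [1.0, 1.0, 0.0, 0.0]
y_pred = [1.0, 0.0, 0.0, 0.0]

score = dice_coeff_plain(y_true, y_pred)
print(score)      # (2*1 + 1) / (2 + 1 + 1) = 0.75
print(1 - score)  # corresponding dice_loss = 0.25
```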

0 answers:

No answers