应该定义“密集”输入的最后一个维度。发现“无”

时间:2019-02-13 07:48:55

标签: python tensorflow keras

我对tensorflow非常陌生,并尝试为自己的图像集创建一个简单的二进制分类器。它们都是226x226灰度PNG图片。我不断收到错误“ ValueError:应定义Dense的输入的最后维度。找到了None”。我已经坚持了好几天,似乎无法以一种可行的方式来塑造我的模型/数据集。有人可以帮忙吗?任何可能相关的代码都应在下面。预先感谢。

##img parser
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_png(image_string)
  image_decoded = tf.image.resize_images(image_decoded,[226,226])
  return image_decoded, label

#img processor function
#input: dir
#output: dataset
def imgPrcs(dir):
    labelArr = [];
    filenames = [];
    src = dir;

    for fname in os.listdir(src):
        png = os.path.join(src, fname);
        filenames.append(png);
        if os.path.isfile(png):
            #extract label
            with open(png, 'rb') as fobj:
                data = fobj.read()
            data_arr = [];
            for chunk_type, chunk_data in chunk_iter(data):
                if   chunk_type == b'iTXt':
                    data_arr.append(chunk_data.decode());
            label = int(data_arr[1][-1:]);

            #add label
            labelArr.append(label);

    labels = tf.constant(labelArr)
    filename_q = tf.constant(filenames)

    dataset = tf.data.Dataset.from_tensor_slices((filename_q, labels))
    dataset = dataset.map(_parse_function)

    #return variables
    return dataset;

#create labels and datasets
print('Compiling images and labels...\n');
trainData = imgPrcs('./train/');
testData = imgPrcs('./test/');
valData = imgPrcs('./validate/');


#Create Model
print('Creating Model...\n');
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(226, 226, None)),
    keras.layers.Dense(128, kernel_initializer='normal', activation='relu'),
    keras.layers.Dense(1,kernel_initializer='normal', activation='sigmoid')
])

print('compile...\n')
model.compile(optimizer='adam', 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy']);


print('train..\n')
#Train Model
model.fit(trainData.make_one_shot_iterator(), epochs=5, steps_per_epoch=385)

print('test')
#Test Model
test_loss, test_acc = model.evaluate(testData.make_one_shot_iterator());

print('Test accuracy:', test_acc);

0 个答案:

没有答案