Error when using the class_weight parameter in Keras's fit function

Time: 2019-03-10 23:41:31

Tags: python tensorflow keras

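I want to test my network on a toy dataset: a handful of examples with two imbalanced classes (0 and 1). Unfortunately, passing the class_weight parameter to improve the balance causes a problem. It seems I am missing something.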

import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.applications.xception import Xception, preprocess_input
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import Adam

# parsing images from TFRecords
def parse_function(proto):
    example = {'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}
    parsed_example = tf.parse_single_example(proto, example)
    image = tf.decode_raw(parsed_example['image_raw'], tf.uint8)
    image = tf.reshape(image, [HEIGHT, WIDTH, DEPTH])
    image = preprocess_input(tf.cast(image, tf.float32))
    return image, parsed_example['label']

def get_data(filepath, schuffle_size=32, batch_size=8, prefetch=1, repeat=None, num_parallel_calls=1):
    dataset = tf.data.TFRecordDataset(filepath)
    if schuffle_size != 0:
        dataset = dataset.shuffle(schuffle_size)
    dataset = dataset.repeat(repeat)
    dataset = dataset.map(parse_function, num_parallel_calls=num_parallel_calls)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(prefetch)
    iterator = dataset.make_one_shot_iterator()
    return iterator

def build_model(number_of_neurons_in_dense_layer, dropout, learning_rate):
    base_model = Xception(weights='imagenet', include_top=False, pooling='avg', input_shape=(HEIGHT, WIDTH, 3))
    for layer in base_model.layers:
        layer.trainable = True
    x = base_model.output
    x = Dropout(dropout)(x)
    x = Dense(number_of_neurons_in_dense_layer, activation='relu')(x)
    x = Dropout(dropout)(x)
    logits = Dense(NUMBER_OF_CLASSES, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=logits)
    model.compile(optimizer=Adam(lr=learning_rate), loss='sparse_categorical_crossentropy', metrics=['categorical_accuracy'])
    return model

global NUMBER_OF_CLASSES, HEIGHT, WIDTH, DEPTH
NUMBER_OF_CLASSES = 2
...
CLASS_WEIGHTS = {
        0: 1,
        1: 7
       }
model = build_model(64, 0.4, 0.001)
train = get_data(..., 8, 2, num_parallel_calls=8)
val = get_data(...., 0, 4, num_parallel_calls=8)
model.fit(train, validation_data=val, epochs=3,steps_per_epoch=8//2,
           validation_steps=8//4, shuffle=False, 
           class_weight=CLASS_WEIGHTS)

I get the following error:

   Original exception was:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
AttributeError: 'Tensor' object has no attribute 'reshape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/usr/model.py", line 147, in main
    class_weight=CLASS_WEIGHTS)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 776, in fit
    shuffle=shuffle)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 2432, in _standardize_user_data
    feed_sample_weight_modes)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 2431, in <listcomp>
    for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 758, in standardize_weights
    y_classes = np.reshape(y, y.shape[0])
  File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 279, in reshape
    return _wrapfunc(a, 'reshape', newshape, order=order)
  File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 61, in _wrapfunc
    return _wrapit(obj, method, *args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 41, in _wrapit
    result = getattr(asarray(obj), method)(*args, **kwds)
TypeError: __index__ returned non-int (type NoneType)

Without the class_weight parameter, the fit function works fine.
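
The last frames of the traceback hint at the cause: standardize_weights calls np.reshape(y, y.shape[0]), and here y is a symbolic Tensor, so y.shape[0] is presumably a dimension object whose __index__ returns None rather than an integer. The sketch below mimics that failure in plain Python. UnknownDim and describe_index_failure are hypothetical names used only for illustration; this is not TensorFlow code.

```python
import operator


class UnknownDim:
    """Hypothetical stand-in for a tensor dimension whose size is unknown."""

    def __index__(self):
        # No concrete size is available, so None is returned instead of an int.
        return None


def describe_index_failure(dim):
    """Try to use `dim` as an integer index and report what happens."""
    try:
        operator.index(dim)
    except TypeError as exc:
        return str(exc)
    return "ok"


# A real integer converts cleanly; the unknown dimension raises TypeError.
print(describe_index_failure(3))
message = describe_index_failure(UnknownDim())
print(message)
```

If this reading is right, the error comes from class_weight handling that expects labels as a NumPy array with a known length, not from the weights themselves.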

1 answer:

Answer 0 (score: 0)

For future reference:

I ran into this error and was able to resolve it by passing an array instead of a dictionary.

For example:

CLASS_WEIGHTS = np.array([1,7])

instead of:

CLASS_WEIGHTS = {0: 1, 1: 7}
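
As a rough illustration of what an array of class weights encodes: index i holds the weight for class i, so indexing the array with integer labels yields one weight per sample. The sketch below uses made-up labels, and the variable names are assumptions. Whether Keras performs exactly this lookup internally for iterator input is not confirmed by the post.

```python
import numpy as np

# Weight for class 0 at index 0, weight for class 1 at index 1.
class_weights = np.array([1.0, 7.0])

# Made-up integer labels for five samples.
labels = np.array([0, 1, 1, 0, 1])

# Integer-array indexing picks the weight that matches each label.
sample_weights = class_weights[labels]
print(sample_weights)
```

Per-sample weights of this kind are what fit accepts through its sample_weight argument when training on NumPy arrays.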