我想在一个玩具数据集上测试我的网络：数据集只有几个样本，包含两个不平衡的类别（0 和 1）。不幸的是，当我使用 class_weight 参数来改善类别平衡时出现了错误，好像我漏掉了什么。
import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.applications.xception import Xception, preprocess_input
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import Adam
# parsing images from TFRecords
def parse_function(proto):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    The record is expected to hold a raw uint8 byte string under
    ``'image_raw'`` and an int64 class id under ``'label'``. The bytes are
    reshaped to ``[HEIGHT, WIDTH, DEPTH]`` (module-level values), cast to
    float32 and passed through Xception's ``preprocess_input``.

    Args:
        proto: a scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(preprocessed_image, label)``.
    """
    features = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    record = tf.parse_single_example(proto, features)

    pixels = tf.reshape(tf.decode_raw(record['image_raw'], tf.uint8),
                        [HEIGHT, WIDTH, DEPTH])
    # preprocess_input works on floating point input, hence the cast first.
    return preprocess_input(tf.cast(pixels, tf.float32)), record['label']
def get_data(filepath, schuffle_size=32, batch_size=8, prefetch=1, repeat=None, num_parallel_calls=1):
    """Build a one-shot iterator over parsed, batched TFRecord examples.

    Pipeline order: optional shuffle -> repeat -> map(parse_function) ->
    batch -> prefetch.

    Args:
        filepath: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        schuffle_size: shuffle buffer size; ``0`` disables shuffling.
        batch_size: number of examples per batch.
        prefetch: number of batches to prefetch.
        repeat: repeat count forwarded to ``Dataset.repeat`` (``None`` means
            repeat indefinitely).
        num_parallel_calls: parallelism for the ``parse_function`` map.

    Returns:
        The dataset's one-shot iterator.
    """
    records = tf.data.TFRecordDataset(filepath)
    records = records.shuffle(schuffle_size) if schuffle_size != 0 else records

    pipeline = (
        records
        .repeat(repeat)
        .map(parse_function, num_parallel_calls=num_parallel_calls)
        .batch(batch_size)
        .prefetch(prefetch)
    )
    return pipeline.make_one_shot_iterator()
def build_model(number_of_neurons_in_dense_layer, dropout, learning_rate):
    """Build and compile an Xception-based classifier.

    An ImageNet-pretrained Xception backbone (no top, global average pooling)
    is followed by Dropout -> Dense(relu) -> Dropout -> Dense(softmax) with
    ``NUMBER_OF_CLASSES`` outputs. Every backbone layer is set trainable.

    Args:
        number_of_neurons_in_dense_layer: width of the hidden Dense layer.
        dropout: dropout rate used before and after the hidden layer.
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        A compiled ``Model`` trained against integer (sparse) class labels.
    """
    # NOTE(review): input is hard-coded to 3 channels while parse_function
    # reshapes to DEPTH; presumably DEPTH == 3 -- confirm.
    base_model = Xception(weights='imagenet', include_top=False, pooling='avg',
                          input_shape=(HEIGHT, WIDTH, 3))
    # Fine-tune the whole backbone, not only the new head.
    for layer in base_model.layers:
        layer.trainable = True

    x = base_model.output
    x = Dropout(dropout)(x)
    x = Dense(number_of_neurons_in_dense_layer, activation='relu')(x)
    x = Dropout(dropout)(x)
    # Softmax output: these are class probabilities, not logits.
    probabilities = Dense(NUMBER_OF_CLASSES, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=probabilities)
    # The loss takes integer class ids, so the accuracy metric must be the
    # sparse variant too; 'categorical_accuracy' expects one-hot targets and
    # reports meaningless values for integer labels.
    model.compile(optimizer=Adam(lr=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    return model
# These names are read as globals by parse_function() and build_model().
# NOTE(review): the traceback shows this code runs inside main(); there the
# ``global`` statement is what makes the assignments below visible to those
# functions (at module level it would be a harmless no-op).
global NUMBER_OF_CLASSES, HEIGHT, WIDTH, DEPTH
NUMBER_OF_CLASSES = 2
...  # elided in the original (presumably HEIGHT, WIDTH, DEPTH, ...)
CLASS_WEIGHTS = {
    0: 1,
    1: 7,
}

model = build_model(64, 0.4, 0.001)
# ``...`` is a placeholder for the TFRecord path(s) elided in the original.
train = get_data(..., schuffle_size=8, batch_size=2, num_parallel_calls=8)
val = get_data(..., schuffle_size=0, batch_size=4, num_parallel_calls=8)

model.fit(
    train,
    validation_data=val,
    epochs=3,
    steps_per_epoch=8 // 2,   # presumably 8 training examples / batch_size 2
    validation_steps=8 // 4,  # presumably 8 validation examples / batch_size 4
    shuffle=False,
    # NOTE(review): with a dataset iterator the targets are symbolic Tensors
    # with an unknown batch dimension. Per the traceback, this TF version's
    # standardize_weights() then calls np.reshape(y, y.shape[0]) on them and
    # fails. Passing an np.array instead avoids the error, but verify the
    # weights are actually applied. Emitting per-example sample weights from
    # the input pipeline is a more reliable alternative.
    class_weight=CLASS_WEIGHTS,
)
我遇到以下错误
Original exception was:
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 51, in _wrapfunc
return getattr(obj, method)(*args, **kwds)
AttributeError: 'Tensor' object has no attribute 'reshape'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/usr/model.py", line 147, in main
class_weight=CLASS_WEIGHTS)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 776, in fit
shuffle=shuffle)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 2432, in _standardize_user_data
feed_sample_weight_modes)
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 2431, in <listcomp>
for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights,
File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 758, in standardize_weights
y_classes = np.reshape(y, y.shape[0])
File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 279, in reshape
return _wrapfunc(a, 'reshape', newshape, order=order)
File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 61, in _wrapfunc
return _wrapit(obj, method, *args, **kwds)
File "/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py", line 41, in _wrapit
result = getattr(asarray(obj), method)(*args, **kwds)
TypeError: __index__ returned non-int (type NoneType)
不传 class_weight 参数时，fit 函数可以正常运行。
答案 0（得分：0）
仅供以后参考:
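我也遇到过这个错误，通过传入数组而不是字典解决了它。（说明：从回溯可以看到，standardize_weights 对 y 调用了 np.reshape(y, y.shape[0])。使用数据集迭代器时，y 是批大小未知（None）的符号张量，因此会报错。传入数组之所以不再报错，可能是因为绕过了这段处理字典的代码，请自行确认类别权重确实被应用了。）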
例如
CLASS_WEIGHTS = np.array([1,7])
代替:
CLASS_WEIGHTS = {0: 1, 1: 7}