如何调整图像从灰度到RGB的变化?

时间:2017-04-27 04:16:58

标签: tensorflow tflearn

我似乎总是陷入困境,迷失在如何重塑数据以适应模型。我认为输入和输出数据的形状必须匹配,但我一直迷失在如何解决这个问题上。

我认为我的主要问题是灰度图像和RGB图像的存储方式不同。 [1] vs [255,255,255]

所以如果:

screen = cv2.cvtColor(screen,cv2.COLOR_BGR2RGB)

更改为:

screen = cv2.cvtColor(screen,cv2.COLOR_BGR2GRAY)

一切正常。

有问题的代码:

# Capture Data (CUT SHORT)
WIDTH = 160
HEIGHT = 120
screen = cv2.cvtColor(screen, cv2.COLOR_BGR2RGB)
screen = cv2.resize(screen, (WIDTH, HEIGHT))
dataset = []
output = [0, 0, 0, 0]
dataset.append([screen, output])
np.save("training.npy", dataset)

# Build Model
https://github.com/tflearn/tflearn/blob/master/examples/images/alexnet.py

# Changed to match output.
network = fully_connected(network, 4, activation='softmax')

# Train Data
WIDTH = 160
HEIGHT = 120
LR = 1e-3
EPOCHS = 5
MODEL_NAME = "HELP"

model = alexnet(WIDTH, HEIGHT, LR)

for i in range(EPOCHS):
    train_data = np.load("training.npy".format(i))

    train = train_data[:-100]
    test = train_data[-100:]

    X = np.array([i[0] for i in train]).reshape(-1,WIDTH,HEIGHT,1)
    Y = [i[1] for i in train]

    test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
    test_y = [i[1] for i in test]

    model.fit({'input': X}, {'targets': Y}, n_epoch=1, validation_set=({'input': test_x}, {'targets': test_y}), 
        snapshot_step=500, show_metric=True, run_id=MODEL_NAME)

    model.save(MODEL_NAME)

错误: 线程Thread-3中的异常: Traceback(最近一次调用最后一次):   文件" C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ threading.py",第914行,在_bootstrap_inner中     self.run()   文件" C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ threading.py",第862行,运行中     self._target(* self._args,** self._kwargs)   文件" C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ data_flow.py",第187行,在fill_feed_dict_queue中     data = self.retrieve_data(batch_ids)   在retrieve_data中的文件" C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ data_flow.py",第222行     utils.slice_array(self.feed_dict [key],batch_ids)   文件" C:\ Users \ TF \ AppData \ Local \ Programs \ Python \ Python35 \ lib \ site-packages \ tflearn \ utils.py",第187行,在slice_array中     返回X [开始]

IndexError:索引2936超出了轴0的大小为1900

的范围

1 个答案:

答案 0 :(得分:0)

博士。 Robert Kirchgessner: 输入数据集中有三个通道。

np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,3)

在alexnet:

network = input_data(shape=[None, width, height, 3], name='input')