IndexError: invalid index to scalar variable when batching each[2]

Posted: 2018-11-25 19:51:11

Tags: python tensorflow machine-learning

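My code is shown below. It keeps failing with an "invalid index to scalar variable" error, and I still don't know why. The problem is reported at line 90: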

```python
batch = each[2][i:i + batch_size]
```

This line raises `IndexError: invalid index to scalar variable.`
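For context, NumPy raises this exact message when code indexes or slices a NumPy scalar (for example a `numpy.float32`) as if it were a sequence. The sketch below reproduces it with a made-up score array (the values and the names `scores`, `row`, `value` are illustrative only). It suggests that, at the failing line, `each[2]` evaluated to a NumPy scalar rather than to the list of filenames produced by `os.walk`.

```python
import numpy as np

# Illustrative 2-D array of class scores: one row per image, one column per class.
scores = np.array([[0.1, 0.7, 0.2],
                   [0.3, 0.3, 0.4]], dtype=np.float32)

row = scores[-1]   # 1-D array holding one image's scores
value = row[2]     # indexing a 1-D float32 array yields a numpy.float32 scalar

message = None
try:
    value[0:4]     # slicing a NumPy scalar, like each[2][i:i + batch_size]
except IndexError as exc:
    message = str(exc)

print(type(value).__name__, message)
```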

The full code of the function is as follows:

```python
import os
import sys

from tqdm import tqdm

# load_graph, read_tensor_from_image_file and predict are helper functions
# defined elsewhere in the script (they are not shown in the question).


def predict_on_frames(frames_folder, model_file, input_layer, output_layer, batch_size):
    input_height = 299
    input_width = 299
    input_mean = 0
    input_std = 255
    graph = load_graph(model_file)

    # Keep only the os.walk entries whose directory name is a label folder.
    labels_in_dir = os.listdir(frames_folder)
    frames = [each for each in os.walk(frames_folder) if os.path.basename(each[0]) in labels_in_dir]

    predictions = []
    for each in frames:
        # each is a (dirpath, dirnames, filenames) tuple from os.walk
        label = each[0]
        print("Predicting on frame of %s\n" % (label))
        for i in tqdm(range(0, len(each[2]), batch_size), ascii=True):
            batch = each[2][i:i + batch_size]  # "line 90" in the original file
            try:
                batch = [os.path.join(label, frame) for frame in batch]
                frames_tensors = read_tensor_from_image_file(batch, input_height=input_height, input_width=input_width, input_mean=input_mean, input_std=input_std)
                pred = predict(graph, frames_tensors, input_layer, output_layer)
                pred = [[each.tolist(), os.path.basename(label)] for each in pred]
                predictions.extend(pred)

            except KeyboardInterrupt:
                print("You quit with ctrl+c")
                sys.exit()

            except Exception as e:
                print("Error making prediction: %s" % (e))
                x = input("\nDo You Want to continue on other samples: y/n")
                if x.lower() == 'y':
                    continue
                else:
                    sys.exit()
    return predictions
```
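One plausible cause, assuming the script runs under Python 2 and `predict` returns a 2-D NumPy array: the comprehension `pred = [[each.tolist(), os.path.basename(label)] for each in pred]` reuses the name `each`. In Python 2, a list comprehension's loop variable leaks into the enclosing function scope, so after the first batch `each` is rebound to the last row of `pred`. On the next pass of the inner loop, `each[2]` is then a single NumPy scalar, and slicing it raises this error. (Python 3 gives comprehensions their own scope, so this particular leak does not happen there.) The sketch below shows the batching step with unpacked, distinct names; `iter_batches` is a hypothetical helper, not part of the original code.

```python
import os


def iter_batches(walk_entries, batch_size):
    """Yield (dirpath, paths) pairs with at most batch_size paths each.

    walk_entries is any iterable of (dirpath, dirnames, filenames) tuples,
    such as the output of os.walk(). Unpacking the tuple into named
    variables means no later comprehension can clobber the outer loop
    variable, even under Python 2's leaky comprehension scoping.
    """
    for dirpath, _dirnames, filenames in walk_entries:
        for start in range(0, len(filenames), batch_size):
            chunk = filenames[start:start + batch_size]
            yield dirpath, [os.path.join(dirpath, name) for name in chunk]


if __name__ == "__main__":
    # Hypothetical os.walk-style entry used only for illustration.
    fake_walk = [("frames/cat", [], ["a.jpg", "b.jpg", "c.jpg"])]
    for label_dir, paths in iter_batches(fake_walk, 2):
        print(label_dir, paths)
```

Inside the prediction loop, the same idea applies to the result list: iterating with a fresh name, e.g. `[[row.tolist(), os.path.basename(label)] for row in pred]`, leaves the outer `each` untouched.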
