从简单的python列表创建图像和标签的Tensorflow数据集

时间:2020-04-29 05:58:58

标签: python tensorflow tensorflow-datasets

我一直在尝试利用各种灵感-特别是this one-创建带标签的图像数据集以传递给model.fit()

我的代码似乎与该问题在the answer中给出的代码相同……与问题的操作相比,_parse_function()略有不同:

def load_image( path, label ):
  file_contents = tf.io.read_file( path )
  image = tf.image.decode_image( file_contents )
  image = tf.image.convert_image_dtype( image, tf.float32 )
  return image, label

我可以在python命令行中独立测试此功能,例如image, label = load_image( "tiger.jpg", "Tiger" ),最后以标签"Tiger"image[0][0]结束,该标签正确对应于图像的左上像素:

>>> image[0][0]
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.37254903, 0.5529412 , 0.854902  ], dtype=float32)>

同样,如果我在程序中尝试print( image[ 0 ][ 0 ],则会得到:

tf.Tensor([0.37254903 0.5529412  0.854902  ], shape=(3,), dtype=float32)

我是python的新手,所以我希望它们只是一个主题的等效变体,但是无论哪种方式,当我将所有内容传递给程序中的model.fit()时,我都会得到:

 ValueError: Cannot take the length of shape with unknown rank.

任何主题的变化都没有使我超越这一点。我从数据集中消除了所有管道操作(例如,没有.shuffle(),没有.repeat(),没有.batch()),因此我只使用了.map()函数,并且得到了相同的结果错误结果。我可以看到的唯一错误可能是在上面的load_image()函数中,或在调用代码中:

dataset = tf.data.Dataset.from_tensor_slices( ( images, labels ) )   # tf.constant() does not change error
dataset = dataset.map( load_map )
model.fit( dataset, epochs=100 )

是什么原因导致错误?

2 个答案:

答案 0 :(得分:0)

ifstream ifile(filename); if(ifile.fail()){ return false; } ifile >> num_of_webpages; for(int i=0; i < num_of_webpages; i++){ Page temp; ifile >> pageid; ifile >> page_URL; ifile >> pagerank; string myline; getline(ifile, myline); stringstream ss(myline); while(ss >> outgoing_link){ temp.add_url_to_links(outgoing_link); cout << outgoing_link << endl; } } 存在一个已知问题-无法正确设置形状信息(请参见here。您可以使用更具体的调用-例如decode_imagedecode_jpeg

此外...您将遇到的下一个问题是您不能直接使用“老虎”之类的标签。如果“老虎”在诸如““狮子”,“老虎”,“斑马”,“猿”,...]等类别的列表中,则您需要在此类列表中使用“老虎”的索引(即decode_png)或单一表示形式(即1

答案 1 :(得分:-1)

请检查此tutorial以获得信息!

您可以首先使用结尾列中的标签和像素作为特征来构建csv文件。然后像这样经历:

titanic_csv_ds = tf.data.experimental.make_csv_dataset(
    titanic_file_path,
    batch_size=5, # Artificially small to make examples easier to show.
    label_name='survived',
    num_epochs=1,
    ignore_errors=True,)