加载数据集以训练模型

时间:2020-04-12 16:10:05

标签: python python-3.x lstm cnn

#load the data 
def load_photos(filename):
    file = load_doc("C:/Users/Project/Flickr8k_text/Flickr_8k.trainImages.txt")
    photos = file.split("\n")[:-1]
    return photos

def load_clean_descriptions(filename, photos): 
    #loading clean_descriptions
    file = load_doc("C:/Users/Project/descriptions.txt")
    descriptions = {}
    for line in file.split("\n"):
        words = line.split()
        if len(words)<1 :
            continue
        image, image_caption = words[0], words[1:]
        if image in photos:
            if image not in descriptions:
                descriptions[image] = []
            desc = '<start> ' + " ".join(image_caption) + ' <end>'
            descriptions[image].append(desc)
    return descriptions

def load_features(filename):
    #loading all features
    all_features = load(open("C:/Users/Project/features.p","rb"))
    #selecting only needed features
    features = {k:all_features[k] for k in photos}
    return features
filename = dataset_text + "/" + "C:/Users/Project/Flickr8k_text\Flickr8k.token.txt"
#"C:/Users/Project/Flickr8k_text/Flickr_8k.trainImages.txt"
#train = loading_data(filename)
train_imgs = load_photos(filename)
train_descriptions = load_clean_descriptions(filename,photos)
train_features = load_features(train_imgs)

错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-111-e33556710077> in <module>
      8 #"C:/Users/Project/Flickr8k_text/Flickr_8k.trainImages.txt"
      9 #train = loading_data(filename)
---> 10 train_imgs = load_photos(filename)
     11 train_descriptions = load_clean_descriptions(filename,photos)
     12 train_features = load_features(train_imgs)

<ipython-input-108-32aceb470773> in load_photos(filename)
      1 #load the data
      2 def load_photos(filename):
----> 3     file = load_doc("C:/Users/Project/Flickr8k_text/Flickr_8k.trainImages.txt")
      4     photos = file.split("\n")[:-1]
      5     return photos

TypeError: load_doc() takes 0 positional arguments but 1 was given

我正在从事“使用Lstm和Cnn生成图像标题”的工作,如果有兴趣帮助我的人,我会在模型培训中感到震惊,请在下面评论 注意:我是使用正确的数据集和路径来训练模型的 请重播笔记

1 个答案:

答案 0 :(得分:0)

def load_doc(filename):

    file = open(filename, 'r')
    text = file.read()
    file.close()
    return text
 
def load_photos(filename):
    file = load_doc(filename)
    photos = file.split("\n")[:-1]
    return photos
filename = "/content/drive/MyDrive/Dataset/Flickr8k_text/Flickr_8k.trainImages.txt"
train_imgs = load_photos(filename)