How to use a custom dataset with PyTorch / few-shot-vid2vid

Time: 2020-02-24 09:43:46

Tags: python machine-learning computer-vision pytorch nvidia

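I want to use few-shot-vid2vid with my own dataset, created from FaceForensics footage. I generated the image sequences with ffmpeg and the keypoints with dlib. When I try to start the training script, I get the error below. What exactly is the problem? The small dataset provided with the repository works for me.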

CustomDatasetDataLoader
485 sequences
dataset [FaceDataset] was created
Resuming from epoch 1 at iteration 0
create web directory ./checkpoints/face/web...
---------- Networks initialized -------------
---------- Optimizers initialized -------------
./checkpoints/face/latest_net_G.pth not exists yet!
./checkpoints/face/latest_net_D.pth not exists yet!
model [Vid2VidModel] was created
Traceback (most recent call last):
  File "train.py", line 73, in <module>
    train()
  File "train.py", line 40, in train
    for idx, data in enumerate(dataset, start=trainer.epoch_iter):
  File "/home/keno/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 819, in __next__
    return self._process_data(data)
  File "/home/keno/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 846, in _process_data
    data.reraise()
  File "/home/keno/.local/lib/python3.7/site-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/keno/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/keno/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/keno/.local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/keno/repos/few-shot-vid2vid/data/fewshot_face_dataset.py", line 103, in __getitem__
    Li = self.get_face_image(keypoints, transform_L, ref_img.size)
  File "/home/keno/repos/few-shot-vid2vid/data/fewshot_face_dataset.py", line 168, in get_face_image
    x = keypoints[sub_edge, 0]
IndexError: index 82 is out of bounds for axis 0 with size 82

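Edit: how the dataset was set up. Following the directory structure of the provided sample dataset, I created the image sequences from the video footage with ffmpeg -i _video_ -o %05d.jpg. I then generated the keypoints with dlib's landmark detection, based on the code example provided on the dlib website. I extended the example code to 68 points and saved them to .txt files: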

import re
import sys
import os
import dlib
import glob

# if len(sys.argv) != 4:
#     print(
#         "Give the path to the trained shape predictor model as the first "
#         "argument and then the directory containing the facial images.\n"
#         "For example, if you are in the python_examples folder then "
#         "execute this program by running:\n"
#         "    ./face_landmark_detection.py shape_predictor_68_face_landmarks.dat ../examples/faces\n"
#         "You can download a trained facial shape predictor from:\n"
#         "    http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2")
#     exit()

predictor_path = sys.argv[1]
faces_folder_path = sys.argv[2]
text_file_path = sys.argv[3]

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(predictor_path)
win = dlib.image_window()

for f in glob.glob(os.path.join(faces_folder_path, "*.jpg")):
    file_number = os.path.split(f)
    print(file_number[1])
    file_number = os.path.splitext(file_number[1])
    file_number = file_number[0]
    export_path = os.path.join(text_file_path, '%s.txt' % file_number)
    text = open(export_path,"w+")

    print("Processing file: {}".format(f))
    img = dlib.load_rgb_image(f)

    win.clear_overlay()
    win.set_image(img)

    # Ask the detector to find the bounding boxes of each face. The 1 in the
    # second argument indicates that we should upsample the image 1 time. This
    # will make everything bigger and allow us to detect more faces.
    dets = detector(img, 1)
    print("Number of faces detected: {}".format(len(dets)))
    for k, d in enumerate(dets):
        print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
            k, d.left(), d.top(), d.right(), d.bottom()))
        # Get the landmarks/parts for the face in box d.
        shape = predictor(img, d)
        for i in range(67):
            result = str(shape.part(i))
            result = result.strip("()")
            print(result)
            text.write(result + '\n')
        # Draw the face landmarks on the screen.
        win.add_overlay(shape)

    text.close()
    win.add_overlay(dets)

1 Answer:

Answer 0 (score: 1)

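for i in range(67):

This is incorrect; for the 68 facial landmarks you should use range(68). You can verify this with python -c "for i in range(67): print(i)", which only counts from 0 to 66 (67 items in total). python -c "for i in range(68): print(i)" counts from 0 to 67 (68 items in total) and so covers the whole set of facial landmarks.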
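
Because of this off-by-one, every keypoint file written by the script in the question is one row short. Files for frames where dlib detected no face, or more than one face, would also have the wrong number of rows, since the script writes one set of points per detection. Below is a minimal sketch for spotting such files before training. It assumes one landmark per line, as the question's script writes them, and that the dataset loader needs exactly 68 points per frame. The function names are hypothetical and not part of few-shot-vid2vid or dlib.

```python
import tempfile
from pathlib import Path

# Assumption: the loader expects the full 68-point dlib landmark set per frame.
EXPECTED_POINTS = 68


def count_points(path):
    """Count the non-empty lines (one landmark per line) in a keypoint file."""
    with open(path) as handle:
        return sum(1 for line in handle if line.strip())


def find_bad_keypoint_files(folder, expected=EXPECTED_POINTS):
    """Return {file name: point count} for each .txt file with a wrong count."""
    bad = {}
    for path in sorted(Path(folder).glob("*.txt")):
        n = count_points(path)
        if n != expected:
            bad[path.name] = n
    return bad


if __name__ == "__main__":
    # Demo on throwaway files: one written with range(67), one with range(68).
    with tempfile.TemporaryDirectory() as tmp:
        for name, n in (("00001.txt", 67), ("00002.txt", 68)):
            rows = "".join("%d, %d\n" % (i, i) for i in range(n))
            (Path(tmp) / name).write_text(rows)
        print(find_bad_keypoint_files(tmp))  # prints {'00001.txt': 67}
```

Running find_bad_keypoint_files on each keypoint directory of the custom dataset should return an empty dict once the files have been regenerated with range(68).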