我是python和tensorflow的新手。我正在使用MNIST进行线性回归教程。那是我遇到这个问题的时候。 我使用下面的代码下载和提取数据。
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
因此,当我将数据下载到指定位置后运行此程序时。多次运行会导致数据再次下载。我该怎么做才能简单地提取已下载的数据。
答案 0 :(得分:0)
您可以简单地提供确切位置作为read_data_sets
的第一个参数mnist = input_data.read_data_sets("C:/ABC/Desktop/MNIST_data/", one_hot = True)
如果在提取mnist数据时遇到同样的问题,可以使用此代码位
root = os.path.splitext(os.path.splitext(filename)[0])[0]
print(root)
if os.path.isdir(root) and not force:
print('%s already present - %s.'%(root, filename))
else:
print('Extracting data for %s.'%root)
tar = tarfile.open(filename)
sys.stdout.flush()
tar.extractall(data_root)
tar.close()
data_folders = [
os.path.join(root, d) for d in sorted(os.listdir(root))
if os.path.isdir(os.path.join(root, d))
]
此处,文件名指定已提取的tar文件的位置, data_root 是存储数据的目标。您可以将此代码位定义为返回数据文件夹的函数。