我正在尝试使用此代码
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb
但是使用tf.data来处理所有文本数据。
我使用GPU运行时选项在Google Colaboratory上运行。
这是原始相关代码
from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from sklearn.manifold import TSNE
url = 'http://mattmahoney.net/dc/'
def maybe_download(filename, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
filename, _ = urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified %s' % filename)
else:
print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename
filename = maybe_download('text8.zip', 31344016)
def read_data(filename):
"""Extract the first file enclosed in a zip file as a list of words"""
with zipfile.ZipFile(filename) as f:
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data
words = read_data(filename)
print('Data size %d' % len(words))
这是我最好尝试改变它以使用tf.data来处理文本数据。
from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from sklearn.manifold import TSNE
url = 'http://mattmahoney.net/dc/'
def maybe_download(filename, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
filename, _ = urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified %s' % filename)
else:
print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename
filename = maybe_download('text8.zip', 31344016)
def read_data(filename):
"""Extract the first file enclosed in a zip file as a list of words"""
with zipfile.ZipFile(filename) as f:
data = tf.data.Dataset.from_tensor_slices( f.read(f.namelist()[0]) )
return data
datasetTest = read_data(filename)
基本上我从原来的
改变了这一行data = tf.compat.as_str(f.read(f.namelist()[0])).split()
到这一行
data = tf.data.Dataset.from_tensor_slices( f.read(f.namelist()[0]) )
新的不会像旧行一样按空格分割单词,所以我有另一行#datasetTest = datasetTest.map(lambda string:tf.string_split([string])。values)但我评论了它试着指出瓶颈的位置。
旧代码在一两分钟内运行。新的永远不会完成执行。它通常运行30-40分钟,然后colab说它已经崩溃并重新启动运行时。