Very poor speed/memory performance when using tf.data for text data

Date: 2018-05-21 22:40:54

Tags: tensorflow

I am trying to use the code from

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb

but with tf.data handling all of the text data.

I am running on Google Colaboratory with the GPU runtime option.

Here is the relevant original code:

from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from sklearn.manifold import TSNE

url = 'http://mattmahoney.net/dc/'

def maybe_download(filename, expected_bytes):
  """Download a file if not present, and make sure it's the right size."""
  if not os.path.exists(filename):
    filename, _ = urlretrieve(url + filename, filename)
  statinfo = os.stat(filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified %s' % filename)
  else:
    print(statinfo.st_size)
    raise Exception(
      'Failed to verify ' + filename + '. Can you get to it with a browser?')
  return filename

filename = maybe_download('text8.zip', 31344016)

def read_data(filename):
  """Extract the first file enclosed in a zip file as a list of words"""
  with zipfile.ZipFile(filename) as f:
    data = tf.compat.as_str(f.read(f.namelist()[0])).split()
  return data

words = read_data(filename)
print('Data size %d' % len(words))

Here is my best attempt at changing it to use tf.data for the text data:

from __future__ import print_function
import collections
import math
import numpy as np
import os
import random
import tensorflow as tf
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from sklearn.manifold import TSNE

url = 'http://mattmahoney.net/dc/'

def maybe_download(filename, expected_bytes):
  """Download a file if not present, and make sure it's the right size."""
  if not os.path.exists(filename):
    filename, _ = urlretrieve(url + filename, filename)
  statinfo = os.stat(filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified %s' % filename)
  else:
    print(statinfo.st_size)
    raise Exception(
      'Failed to verify ' + filename + '. Can you get to it with a browser?')
  return filename

filename = maybe_download('text8.zip', 31344016)

def read_data(filename):
  """Extract the first file enclosed in a zip file as a list of words"""
  with zipfile.ZipFile(filename) as f:
    data = tf.data.Dataset.from_tensor_slices( f.read(f.namelist()[0]) )
  return data

datasetTest = read_data(filename)

Basically, I changed this line from the original:
data = tf.compat.as_str(f.read(f.namelist()[0])).split()

to this line:

data = tf.data.Dataset.from_tensor_slices( f.read(f.namelist()[0]) )

The new version does not split the words on whitespace the way the old line did, so I have another line, #datasetTest = datasetTest.map(lambda string: tf.string_split([string]).values), but I left it commented out to try to pinpoint where the bottleneck is.
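For context, that commented-out map is intended to turn each string element into its whitespace-separated word tokens. In plain Python, the same transformation looks like this (a minimal sketch without TensorFlow, just to show the intended output; the sample strings are made up):

```python
# Plain-Python sketch of what tf.string_split([s]).values is meant to
# produce: each input string is broken into a flat list of word tokens.
def split_into_tokens(strings):
    tokens = []
    for s in strings:
        tokens.extend(s.split())  # split on any run of whitespace
    return tokens

print(split_into_tokens(["anarchism originated", "as a term"]))
# -> ['anarchism', 'originated', 'as', 'a', 'term']
```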

The old code runs in a minute or two. The new code never finishes executing: it typically runs for 30-40 minutes, and then Colab reports that it has crashed and restarts the runtime.
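My guess at the culprit (an assumption, I have not profiled it) is that from_tensor_slices bakes the entire ~100 MB text8 string into the graph as one giant constant, which every downstream op then has to drag around. One alternative I am considering is splitting the text in ordinary Python first, as the original code did, and then streaming words from a generator. A minimal sketch, using a tiny in-memory zip as a stand-in for text8.zip:

```python
import io
import zipfile

def word_generator(zip_path):
    """Split the archived text in ordinary Python (as the original code
    did) and yield one word at a time -- the element-by-element shape
    that tf.data.Dataset.from_generator expects."""
    with zipfile.ZipFile(zip_path) as zf:
        text = zf.read(zf.namelist()[0]).decode("utf-8")
    for word in text.split():
        yield word

# Demonstrate with a tiny in-memory zip standing in for text8.zip.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("text8", "anarchism originated as a term of abuse")
words = list(word_generator(buf))
print(words[:3])
# -> ['anarchism', 'originated', 'as']
```

The generator could then back a dataset via tf.data.Dataset.from_generator(lambda: word_generator('text8.zip'), tf.string), so that only small string elements enter the pipeline instead of a single ~100 MB constant being embedded in the graph (again, that constant is my guess at the bottleneck, not something I have verified).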

0 Answers
