在Tensorflow中配置GPU内存使用情况

时间:2019-08-21 09:12:16

标签: python tensorflow

我有LSTM模型,喜欢配置内存使用情况。 我所做的是

class ActivityRecognition:
     #Utility functions for training:
     def LSTM_RNN(self,_X, _weights, _biases):
          # model architecture based on "guillaume-chevalier" and "aymericdamien" under the MIT license.
          _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size
          _X = tf.reshape(_X, [-1, self.n_input])
          # Rectifies Linear Unit activation function used
          _X = tf.nn.relu(tf.matmul(_X, _weights['hidden']) + _biases['hidden'])
          # Split data because rnn cell needs a list of inputs for the RNN inner loop
          _X = tf.split(_X, self.n_steps, 0)

          # Define two stacked LSTM cells (two recurrent layers deep) with tensorflow
          lstm_cell_1 = tf.contrib.rnn.BasicLSTMCell(self.n_hidden, forget_bias=1.0, state_is_tuple=True)
          lstm_cell_2 = tf.contrib.rnn.BasicLSTMCell(self.n_hidden, forget_bias=1.0, state_is_tuple=True)
          lstm_cell_3 = tf.contrib.rnn.BasicLSTMCell(self.n_hidden, forget_bias=1.0, state_is_tuple=True)
          lstm_cell_4 = tf.contrib.rnn.BasicLSTMCell(self.n_hidden, forget_bias=1.0, state_is_tuple=True)
          lstm_cells = tf.contrib.rnn.MultiRNNCell([lstm_cell_1, lstm_cell_2, lstm_cell_3, lstm_cell_4], state_is_tuple=True)
          outputs, states = tf.contrib.rnn.static_rnn(lstm_cells, _X, dtype=tf.float32)
          lstm_last_output = outputs[-1]
          return tf.matmul(lstm_last_output, _weights['out']) + _biases['out']
     def __init__(self):
          self.n_steps = 128 # 32 timesteps per series
          self.n_input = 10  #ch4.x,ch4.y,ch7.x,ch7.y,dist4_16,dist7_17
          self.n_hidden = 34 # Hidden layer num of features
          self.n_classes = 3
          self.global_step = tf.Variable(0, trainable=False)
          # Graph input/output
          self.x = tf.placeholder(tf.float32, [None, self.n_steps, self.n_input])
          self.y = tf.placeholder(tf.float32, [None, self.n_classes])
          # Graph weights
          self.weights = {
               'hidden': tf.Variable(tf.random_normal([self.n_input, self.n_hidden])), # Hidden layer weights
               'out': tf.Variable(tf.random_normal([self.n_hidden, self.n_classes], mean=1.0))
          }
          self.biases = {
              'hidden': tf.Variable(tf.random_normal([self.n_hidden])),
              'out': tf.Variable(tf.random_normal([self.n_classes]))
          }
          self.pred = self.LSTM_RNN(self.x, self.weights, self.biases)
          config = tf.ConfigProto()
          config.gpu_options.per_process_gpu_memory_fraction = 0.4
          self.sess = tf.Session(config,...)
          self.init = tf.global_variables_initializer()
          with tf.Session() as self.sess:
                  self.sess.run(self.init)
                  saver = tf.train.Saver()
                  saver.restore(self.sess, tf.train.latest_checkpoint('/data/saac/HumanActivity/TrainModels/'))
                  print("Model restored.")

我在运行代码时出错

File "/data/saac/HumanActivity/ActivityRecognition.py", line 55, in __init__
    self.sess = tf.Session(config,...)
  File "/home/user/venvcuda9_0/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1551, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "/home/user/venvcuda9_0/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 643, in __init__
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
TypeError: graph must be a tf.Graph, but got <class 'ellipsis'>
Exception ignored in: <bound method BaseSession.__del__ of <tensorflow.python.client.session.Session object at 0x7f6e7e402320>>
Traceback (most recent call last):
  File "/home/user/venvcuda9_0/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 736, in __del__
    if self._session is not None:
AttributeError: 'Session' object has no attribute '_session'

怎么了?

1 个答案:

答案 0 :(得分:2)

问题出在tf.Session的构造函数参数中。省略号被识别为图形的参数,这就是错误TypeError: graph must be a tf.Graph, but got <class 'ellipsis'>的含义。

替换

self.sess = tf.Session(config,...)

使用

self.sess = tf.Session(config=config)