How to write a thread-safe Keras data generator function without **kwargs?

Posted: 2018-09-26 06:03:21

Tags: multithreading keras

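I have the following Keras generator function that I need to use in Ray Tune. The inner wrapper function `def g` takes `**kw` (keyword arguments), and I somehow need to get rid of them, because Ray Tune does not support kwargs at the moment.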

   import threading

   import numpy as np


   class Threadsafe_Iter:
       """Takes an iterator/generator and makes it thread-safe by
       serializing calls to the `next` method of the given iterator/generator.
       Can be used with fit_generator(..., workers=x, use_multiprocessing=True/False).
       If use_multiprocessing=True, a generator is created per worker process.
       """
       def __init__(self, it, lock=None):
           self.it = it
           # Reuse the lock passed in by the decorator, or create a private one.
           self.lock = lock if lock is not None else threading.Lock()

       def __iter__(self):
           return self

       def __next__(self):
           # Only one thread at a time may advance the wrapped generator.
           with self.lock:
               return next(self.it)


   def threadsafe_generator(lock=None):
       def wrap(f):
           """A decorator that takes a generator function and makes it thread-safe.
           Taken from
           http://anandology.com/blog/using-iterators-and-generators/
           """
           def g(*a, **kw):
               return Threadsafe_Iter(f(*a, **kw), lock=lock)
           return g
       return wrap

   @threadsafe_generator(lock=None)
   def custom_generator(data=train, batch_size=32, num_negatives=1):
       user_input, item_input, labels = data_generator(data, num_negatives)
       print('generator initiated')
       X_train = np.array(user_input)
       Y_train = np.array(item_input)
       y_train = np.array(labels)
       idx=0
       while True:
           for i in range(p['n_users']):
               index = np.random.choice(len(labels))
               X_train[i] = X_train[index]
               Y_train[i] = Y_train[index]
               y_train[i] = y_train[index]
               X = [X_train[:batch_size], Y_train[:batch_size]]
               y = y_train[:batch_size]
               yield X, y
               print('generator yielded a batch %d' % idx)
               idx+=1
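A minimal sketch of one possible workaround, assuming Ray Tune only objects to the `**kw` in the wrapper's signature (this has not been checked against Ray Tune itself): the wrapper `g` forwards positional arguments only, so `custom_generator` would have to be called as `custom_generator(train, 32, 1)` rather than with keywords. `ThreadsafeIter` and `toy_generator` are hypothetical names used for illustration; `toy_generator` stands in for `custom_generator` because `train`, `data_generator` and `p` are not defined in the question.

```python
import functools
import inspect
import threading


class ThreadsafeIter:
    """Wrap an iterator so that concurrent next() calls are serialized by a lock."""

    def __init__(self, it, lock=None):
        self.it = it
        # Use the caller's lock if one is given, otherwise create a private one.
        self.lock = lock if lock is not None else threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)


def threadsafe_generator(lock=None):
    """Decorator factory; the wrapper forwards positional arguments only (no **kwargs)."""
    def wrap(f):
        @functools.wraps(f)  # keep f's __name__/__doc__ and set __wrapped__
        def g(*args):
            return ThreadsafeIter(f(*args), lock=lock)
        return g
    return wrap


# Hypothetical stand-in for custom_generator: yields start, start+step, ...
@threadsafe_generator(lock=None)
def toy_generator(start=0, step=1):
    n = start
    while True:
        yield n
        n += step


gen = toy_generator(10, 2)  # arguments must now be passed positionally
print(next(gen), next(gen))  # prints: 10 12
```

Because of `functools.wraps`, `inspect.signature(toy_generator)` follows `__wrapped__` and reports the original `(start=0, step=1)`, while the wrapper itself only has `*args`. The trade-off is that keyword calls such as `toy_generator(step=2)` now raise `TypeError`; if specific values are needed, they can be bound positionally in advance, e.g. `functools.partial(toy_generator, 0, 2)`.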
