So, I have the following Keras generator function that I need to use with Ray Tune. The inner function `def g` takes `**kwargs`, and I somehow need to get rid of them, because Ray Tune does not support kwargs at the moment.
import threading

import numpy as np


class Threadsafe_Iter:
    """Takes an iterator/generator and makes it thread-safe by
    serializing calls to the `next` method of the given iterator/generator.
    Can be used with fit_generator(..., workers=x, use_multiprocessing=True/False).
    If use_multiprocessing=True, one generator is created per CPU.
    """
    def __init__(self, it, lock=None):
        self.it = it
        # Reuse a shared lock if one is provided; otherwise create our own.
        self.lock = lock or threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)
def threadsafe_generator(lock=None):
    def wrap(f):
        """A decorator that takes a generator function and makes it thread-safe.
        Taken from
        http://anandology.com/blog/using-iterators-and-generators/
        """
        def g(*a, **kw):
            # The wrapper class above is Threadsafe_Iter (capitalized).
            return Threadsafe_Iter(f(*a, **kw), lock=lock)
        return g
    return wrap
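# One possible way to drop *a/**kw from g entirely (a sketch, assuming every
# decorated generator shares custom_generator's signature below, so no
# **kwargs would be left for Ray Tune to reject) is to spell out the
# parameters explicitly:
#
#     def g(data=train, batch_size=32, num_negatives=1):
#         return Threadsafe_Iter(f(data, batch_size, num_negatives), lock=lock)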
@threadsafe_generator(lock=None)
def custom_generator(data=train, batch_size=32, num_negatives=1):
    # `train`, `data_generator` and `p` are defined elsewhere in my code.
    user_input, item_input, labels = data_generator(data, num_negatives)
    print('generator initiated')
    X_train = np.array(user_input)
    Y_train = np.array(item_input)
    y_train = np.array(labels)
    idx = 0
    while True:
        # Overwrite the first n_users rows with randomly resampled rows.
        for i in range(p['n_users']):
            index = np.random.choice(len(labels))
            X_train[i] = X_train[index]
            Y_train[i] = Y_train[index]
            y_train[i] = y_train[index]
        X = [X_train[:batch_size], Y_train[:batch_size]]
        y = y_train[:batch_size]
        yield X, y
        print('generator yielded a batch %d' % idx)
        idx += 1
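For completeness, here is a minimal sketch of one workaround I am considering, assuming the only problem is the keyword arguments crossing the Ray Tune boundary: bind them up front with functools.partial (standard library, nothing Ray Tune-specific), so that the callable handed to Ray Tune takes no arguments at all.

from functools import partial

# Bind all keyword arguments beforehand; the resulting callable takes no
# arguments, so no **kwargs need to be forwarded through Ray Tune.
bound_generator = partial(custom_generator, data=train, batch_size=32,
                          num_negatives=1)

gen = bound_generator()   # returns the Threadsafe_Iter wrapper as before
X, y = next(gen)          # yields batches exactly like the original call

I am not sure whether this satisfies Ray Tune's restriction, or whether the decorator's `g(*a, **kw)` signature itself is what has to change.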