我正在尝试使用Google的Jax机器学习库编写DC Gan网络。为此,我创建了对象来充当鉴别器和生成器,但是,当我测试鉴别器时,出现了错误:
TypeError: Argument '<__main__.Discriminator object at 0x7fdfa5c6ffd0>' of type <class '__main__.Discriminator'> is not a valid JAX type
我浏览了Jax github页面上的示例,并且从我看到的结果来看,那里的示例都没有使用对象,这使我假设可能无法在Jax中使用对象。但是,如果是这种情况,我真的不明白为什么无法使用对象,这会在将来实现吗?我只是天真地忽略了什么吗?
这是我的鉴别对象:
class Discriminator():
def __init__(self):
self.step_size = 0.0001
self.image_shape = (256,256,3)
self.params = []
num_layers = 6
num_filters = 64
filter_size = 4
self.params.append(create_conv_layer(3,
num_filters,
filter_size,
filter_size,
random.PRNGKey(0)))
for l in range(1, num_layers):
self.params.append(create_conv_layer(64*2**(l-1),
64*2**l,
filter_size,
filter_size,
random.PRNGKey(0)))
self.params.append(create_conv_layer(64*2**num_filters,
1,
filter_size,
filter_size,
random.PRNGKey(0)))
def predict(self):
activations = image
for w, b in params[:-1]:
outputs = conv_forward(activations,w,b,stride=2)
outputs = batch_normalization(outputs)
activations = leaky_relu(outputs)
final_w, final_b = params[-1]
return sigmoid(conv_forward(activations,final_w,final_b,))
def batched_predict(self, images):
shape = [None] + list(self.image_shape)
return vmap(self.predict, in_axes=shape)(self.params, images)
def loss(self, params, images, targets):
preds = self.batched_predict(params, images)
return -np.sum(preds * targets)
def accuracy(self, images, targets):
predicted_class = np.round(np.ravel(batched_predict(images)))
return np.mean(predicted_class == target_class)
@jit
def update(self, params, x, y):
grads = grad(self.loss)(params, x, y)
return [(w - self.step_size * dw, b - self.step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
我在这里更新参数:
num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
discrim = Discriminator()
params = discrim.params
print("lets-a-go!")
for epoch in range(num_epochs):
start_time = time.time()
for step in range(steps_per_epoch):
x, y = simple_data_generator(batch_size)
params = discrim.update(params, x, y)
epoch_time = time.time() - start_time
train_acc = discrim.accuracy(train_images, train_labels)
test_acc = discrim.accuracy(test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))