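Is it possible to use objects in Google's JAX machine learning library?

Asked: 2019-04-17 09:05:41

Tags: python machine-learning

I am trying to write a DCGAN using Google's JAX machine learning library. To do this, I created objects to act as the discriminator and the generator. However, when I test the discriminator, I get this error: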

    TypeError: Argument '<__main__.Discriminator object at 0x7fdfa5c6ffd0>' of type <class '__main__.Discriminator'> is not a valid JAX type

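I went through the examples on the JAX GitHub page, and as far as I can see none of them use objects, which makes me assume that it may not be possible to use objects in JAX. If that is the case, though, I don't really understand why objects can't be used. Will this be supported in the future? Or am I just naively overlooking something?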
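
One likely reading of the error, offered as a sketch rather than a definitive answer: `jit` treats every positional argument of the decorated function, including `self`, as array data to trace, and an arbitrary Python object is not a valid JAX type. A commonly used workaround is to mark `self` as static with `static_argnums`. The `Scaler` class below is a hypothetical toy example, not the code from the question.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical toy class used only to illustrate jit on a method."""

    def __init__(self, factor):
        self.factor = factor

    # Argument 0 is `self`. Marking it static tells jit to treat the object
    # as a hashable compile-time constant instead of trying to trace it.
    @partial(jax.jit, static_argnums=0)
    def apply(self, x):
        return self.factor * x


scaler = Scaler(3.0)
result = scaler.apply(jnp.arange(3.0))
```

Because `self` is static, JAX hashes the object (by identity, unless `__hash__` is overridden) and does not retrace when an attribute such as `factor` is changed later, so this approach fits best when the object is treated as immutable.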

Here is my discriminator object:

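# Assumed imports; the question does not show them.
import time

import jax.numpy as np
from jax import grad, jit, random, vmap

# create_conv_layer, conv_forward, batch_normalization, leaky_relu and
# sigmoid are helper functions that the question does not show.

class Discriminator():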
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (256,256,3)
        self.params = []
        num_layers = 6
        num_filters = 64
        filter_size = 4
        self.params.append(create_conv_layer(3,
                                             num_filters,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))
        for l in range(1, num_layers):
            self.params.append(create_conv_layer(64*2**(l-1),
                                                 64*2**l,
                                                 filter_size,
                                                 filter_size,
                                                 random.PRNGKey(0)))
        # The last loop layer outputs 64*2**(num_layers-1) channels, so the
        # final layer must accept that many input channels.
        self.params.append(create_conv_layer(64*2**(num_layers-1),
                                             1,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))

    def predict(self, params, image):
        # Forward pass for a single image; params is a list of (w, b) pairs.
        activations = image
        for w, b in params[:-1]:
            outputs = conv_forward(activations, w, b, stride=2)
            outputs = batch_normalization(outputs)
            activations = leaky_relu(outputs)
        final_w, final_b = params[-1]
        return sigmoid(conv_forward(activations, final_w, final_b))

    def batched_predict(self, params, images):
        # Share params across the batch (None); map over axis 0 of images.
        return vmap(self.predict, in_axes=(None, 0))(params, images)

    def loss(self, params, images, targets):
        preds = self.batched_predict(params, images)
        return -np.sum(preds * targets)

    def accuracy(self, images, targets):
        preds = self.batched_predict(self.params, images)
        predicted_class = np.round(np.ravel(preds))
        return np.mean(predicted_class == targets)

    @jit
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]
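
In `vmap`, `in_axes` takes one entry per function argument (an axis index, or `None` for an argument that is not batched), not the shape of the data. A minimal sketch of batching a per-image `predict` over a shared set of parameters, assuming a hypothetical dense layer in place of the conv stack and a small stand-in image shape:

```python
import jax
import jax.numpy as jnp


def predict(params, image):
    # Hypothetical per-image model: flatten one image, apply a dense layer.
    w, b = params
    return jax.nn.sigmoid(image.reshape(-1) @ w + b)


# One in_axes entry per argument of `predict`:
# None -> params are shared, 0 -> map over the leading axis of images.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

image_shape = (4, 4, 3)  # small stand-in for the question's (256, 256, 3)
params = (jnp.zeros((4 * 4 * 3,)), jnp.zeros(()))
images = jnp.ones((5,) + image_shape)
preds = batched_predict(params, images)  # one prediction per image
```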

This is where I update the parameters:

num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
discrim = Discriminator()
params = discrim.params

print("lets-a-go!")
for epoch in range(num_epochs):
    start_time = time.time()
    for step in range(steps_per_epoch):
        x, y = simple_data_generator(batch_size)
        params = discrim.update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = discrim.accuracy(train_images, train_labels)
    test_acc = discrim.accuracy(test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
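
An alternative that sidesteps passing an object to `jit` altogether is to keep the parameters outside the object and jit a pure function of `(params, x, y)`, which is the function-only style of the JAX examples. The sketch below assumes a hypothetical one-layer dense model and toy data in place of the question's conv layers and `simple_data_generator`; the step size and the loss mirror the question.

```python
import jax
import jax.numpy as jnp

STEP_SIZE = 1e-4  # same value as self.step_size in the question


def predict(params, x):
    # Hypothetical stand-in for the conv stack: one dense layer + sigmoid.
    w, b = params
    return jax.nn.sigmoid(x @ w + b)


def loss(params, x, y):
    preds = predict(params, x)
    return -jnp.sum(preds * y)


@jax.jit
def update(params, x, y):
    # Every argument is an array or a pytree of arrays, so jit can trace it.
    grads = jax.grad(loss)(params, x, y)
    # Apply one SGD step to each leaf of the params pytree.
    return jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)


# Toy data and zero-initialised parameters, for illustration only.
params = (jnp.zeros((4, 1)), jnp.zeros((1,)))
x = jnp.ones((8, 4))
y = jnp.ones((8, 1))
new_params = update(params, x, y)
```

A class can still hold configuration and the current parameters; only the jitted function needs to take arrays (or pytrees of arrays) as its inputs.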

0 Answers:

No answers