Tensorflow-2中的cifar10数据训练问题

时间:2019-11-27 19:15:12

标签: python tensorflow keras

在 TensorFlow 2 中训练 CIFAR-10 数据时出现以下错误。我参考的是这个教程(tutorial)。

  

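TypeError: Expected float32 passed to parameter 'y' of op 'Equal', got 'collections' of type 'str' instead. Error: Expected float32, got 'collections' of type 'str' instead.

(即:传给操作 “Equal” 的参数 “y” 应为 float32 类型,但实际传入的是类型为 “str” 的 “collections”。)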

我的代码如下(注:下面紧跟的 F# 代码片段与本问题无关,应是页面转载时误混入的内容;与问题相关的 Python 代码从 `class Mymodel` 开始):

open System.Runtime.CompilerServices
open System.Threading.Tasks

/// One step of a task-style computation, as consumed by StateMachine: either a result
/// that is already available, a Task that will produce it, or an awaiter that has not
/// completed yet together with the function that produces the next step afterwards.
type TaskStep<'result> =
/// The result is available immediately.
| Value of 'result
/// The result will be produced by this Task.
| AsyncValue of 'result Task
/// A not-yet-completed awaiter plus the function computing the next step once it completes.
| Continuation of ICriticalNotifyCompletion * (unit -> 'result TaskStep)
/// Drives a chain of TaskSteps, starting from `firstStep`, on an AsyncTaskMethodBuilder.
/// The builder's result type is `'a Task`, so Run returns a Task<Task<'a>> that callers
/// need to unwrap.
and StateMachine<'a>(firstStep) as this =
    // Produces the Task returned by Run; its result is itself a Task<'a>.
    let methodBuilder = AsyncTaskMethodBuilder<'a Task>()
    // Function yielding the next step: initially returns firstStep, then is replaced by
    // each Continuation's `next`.
    let mutable continuation = fun () -> firstStep
    // Runs the current continuation. For a Continuation step, stores `next` and returns
    // the awaiter to wait on. Otherwise it completes the builder (with the result, or with
    // the exception raised) and returns null, meaning there is nothing left to await.
    let nextAwaitable() =
        try
            match continuation() with
            | Value r ->
                // Wrap the plain result to match the builder's `'a Task` result type.
                methodBuilder.SetResult(Task.FromResult(r))
                null
            | AsyncValue t ->
                methodBuilder.SetResult(t)
                null
            | Continuation (await, next) ->
                // Remember what to run once `await` completes.
                continuation <- next
                await
        with
        | exn ->
            methodBuilder.SetException(exn)
            null
    // Mutable copy of `this` so it can be passed by reference (&self) to the builder's
    // Start and AwaitUnsafeOnCompleted, which take the state machine as a byref.
    let mutable self = this

    /// Starts the state machine on the builder and returns the builder's Task
    /// (a Task<Task<'a>>).
    member __.Run() =
        methodBuilder.Start(&self)
        methodBuilder.Task

    interface IAsyncStateMachine with
        /// Advances one step. If that step is waiting on an awaiter, asks the builder to
        /// resume this state machine once the awaiter completes.
        member __.MoveNext() =
            // mutable so it can be passed by reference to AwaitUnsafeOnCompleted.
            let mutable await = nextAwaitable()
            if not (isNull await) then
                methodBuilder.AwaitUnsafeOnCompleted(&await, &self)    
        /// No-op. This state machine is a class instance; presumably there is no separate
        /// copy for the builder to re-associate — confirm if converting to a struct.
        member __.SetStateMachine(_) = 
            () 

/// Binds an awaitable value to a continuation that produces the next TaskStep.
/// Accepts any type whose GetAwaiter() returns an awaiter that implements
/// ICriticalNotifyCompletion and exposes IsCompleted and GetResult; the members are
/// resolved statically through inline SRTP constraints.
type Binder<'out> =
    /// If the awaiter has already completed, calls `continuation` with its result right
    /// away. Otherwise returns a Continuation step that fetches the result and calls
    /// `continuation` only when that step's function is invoked.
    static member inline GenericAwait< ^abl, ^awt, ^inp
                                        when ^abl : (member GetAwaiter : unit -> ^awt)
                                        and ^awt :> ICriticalNotifyCompletion 
                                        and ^awt : (member get_IsCompleted : unit -> bool)
                                        and ^awt : (member GetResult : unit -> ^inp) >
        (abl : ^abl, continuation : ^inp -> 'out TaskStep) : 'out TaskStep =
            let awt = (^abl : (member GetAwaiter : unit -> ^awt)(abl))
            // Fast path: already complete, so continue synchronously.
            if (^awt : (member get_IsCompleted : unit -> bool)(awt)) 
            then continuation (^awt : (member GetResult : unit -> ^inp)(awt))
            // Not complete yet: return the awaiter in a Continuation; GetResult is deferred.
            else Continuation (awt, fun () -> continuation (^awt : (member GetResult : unit -> ^inp)(awt)))

module TaskStep =
    /// Awaits the awaitable `step` and feeds its result to `f`, which yields the next step.
    let inline bind f step : TaskStep<'a> =
        Binder<'a>.GenericAwait(step, f)

    /// Turns a step into a Task. Immediate values and tasks are returned directly;
    /// pending continuations are run on a StateMachine whose nested Task is unwrapped.
    /// Any exception raised while doing so is returned as a faulted Task.
    let inline toTask (step: TaskStep<'a>) =
        try
            match step with
            | Value x -> Task.FromResult x
            | AsyncValue t -> t
            | Continuation _ ->
                let machine = StateMachine<'a>(step)
                machine.Run().Unwrap()
        with exn ->
            let completion = TaskCompletionSource<_>()
            completion.SetException exn
            completion.Task

module Task =
    /// Awaits `task`, passes its result to `f` (which returns the next TaskStep), and
    /// runs the resulting step as a Task.
    let inline bind f task : Task<'a> =
        let step = TaskStep.bind f task
        TaskStep.toTask step

    /// Awaits `task` and transforms its result with the plain function `f`.
    let inline map f task : Task<'b> =
        bind (fun result -> Value (f result)) task

当我改用 compile 和 fit 函数进行训练时,代码可以正常运行。

    class Mymodel(tf.keras.Model):
        """Small CNN classifier: three Conv2D layers, a Flatten, two hidden Dense
        layers and a softmax output layer with `class_size` units."""

        def __init__(self, class_size):
            """Create the layers applied in `call`.

            Args:
                class_size: number of classes, i.e. units in the softmax output layer.
            """
            super().__init__()
            self.class_size = class_size

            conv2d = tf.keras.layers.Conv2D
            dense = tf.keras.layers.Dense
            # Convolutional feature extractor.
            self.conv1 = conv2d(32, kernel_size=3, strides=2, activation='relu')
            self.conv2 = conv2d(64, kernel_size=2, strides=2, activation='relu')
            self.conv3 = conv2d(64, kernel_size=2, strides=1, activation='relu')
            self.flat = tf.keras.layers.Flatten()
            # Fully connected head ending in class probabilities.
            self.d1 = dense(512, activation='relu')
            self.d2 = dense(128, activation='relu')
            self.fd = dense(self.class_size, activation='softmax')

        def call(self, inputs):
            """Forward pass: run `inputs` through every layer in order and return the
            softmax output."""
            stack = (self.conv1, self.conv2, self.conv3,
                     self.flat, self.d1, self.d2, self.fd)
            x = inputs
            for layer in stack:
                x = layer(x)
            return x

    model = Mymodel(10)

    # load_data() returns uint8 image arrays (per the Keras docs). Convert to float32
    # before scaling to [0, 1]; dividing the uint8 arrays by 255.0 directly would
    # materialise float64 copies of the whole dataset.
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
    train_images = train_images.astype('float32') / 255.0
    test_images = test_images.astype('float32') / 255.0

    train_ds = tf.data.Dataset.from_tensor_slices(
        (train_images, train_labels)).shuffle(1000).batch(32)

    test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)

    # define the training and testing objects
    # Mymodel's last layer already applies softmax, so the loss receives probabilities
    # and keeps the default from_logits=False.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

    optimizer = tf.keras.optimizers.Adam()

    # Running metrics that train() prints and resets.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function
    def train_step(images, labels):
        """Run one optimisation step on a batch and update the training metrics.

        Args:
            images: batch of input images fed to `model`.
            labels: class labels for `images`, in the sparse form `loss_object` expects.
        """
        with tf.GradientTape() as tape:
            predictions = model(images)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Accumulate into the metric objects that train() prints and resets.
        # `loss` here is a tensor, not a metric, so it must not be called.
        train_loss(loss)
        train_accuracy(labels, predictions)


@tf.function  # the '@' is required; a bare `tf.function` line does nothing
def test_step(images, labels):
    """Evaluate `model` on one batch and update the test metrics (no weight update).

    Args:
        images: batch of input images fed to `model`.
        labels: class labels for `images`, in the sparse form `loss_object` expects.
    """
    predictions = model(images)
    t_loss = loss_object(labels, predictions)
    test_loss(t_loss)
    test_accuracy(labels, predictions)


def train(epochs=5):
    """Train on `train_ds` and evaluate on `test_ds`, once per epoch.

    After each epoch, prints that epoch's training/test loss and accuracy. It then
    resets the metrics, so each printed line covers a single epoch.

    Args:
        epochs: number of passes over `train_ds` (default 5).
    """
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

    for epoch in range(epochs):
        for images, labels in train_ds:
            train_step(images, labels)

        for images, labels in test_ds:
            test_step(images, labels)

        print(template.format(epoch + 1,
                              train_loss.result(),
                              train_accuracy.result() * 100,
                              test_loss.result(),
                              test_accuracy.result() * 100))

        # Reset the metrics for the next epoch. This must run inside the epoch loop;
        # otherwise the printed values accumulate over all epochs.
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            metric.reset_states()


train()

非常感谢任何帮助。

2 个答案:

答案 0 :(得分:1)

在损失函数中设置from_logits = True。

# NOTE(review): Mymodel's final Dense layer already applies softmax, so with
# from_logits=True the loss applies softmax a second time. Either drop that
# activation or keep the default from_logits=False.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

它解决了错误!

答案 1 :(得分:0)

我认为你可以先尝试在第一层中指定 input_shape 参数:

# Give the first layer the per-example input shape (no batch dimension).
self.conv1 = tf.keras.layers.Conv2D(32, kernel_size=3, strides=2, activation='relu', input_shape=(w, h, n_channel))