Tags: python-3.x tensorflow2.x quantization-aware-training

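I am using TensorFlow 2.2, tensorflow_model_optimization and Python 3.8. I am trying to quantize and train a LeNet-300-100 dense neural network with a sparsity of 91.3375%, meaning that 91.3375% of its weights are zero. I have been following the Quantization TF tutorial, and I want to train this sparse, quantized network with tf.GradientTape instead of q_aware_model.fit().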
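Sparsity here means the share of parameters that are exactly zero. A minimal NumPy sketch for computing it over a list of weight arrays follows; `sparsity_percent` is a hypothetical helper, not a TensorFlow or tfmot function.

```python
import numpy as np


def sparsity_percent(weights):
    """Return the percentage of entries that are exactly zero across all arrays."""
    total = sum(w.size for w in weights)
    nonzero = sum(int(np.count_nonzero(w)) for w in weights)
    return 100.0 * (total - nonzero) / total


# Toy example: a 4x5 kernel with 3 non-zero entries plus an all-zero bias of 5.
kernel = np.zeros((4, 5), dtype=np.float32)
kernel[0, 0] = kernel[1, 2] = kernel[3, 4] = 0.5
bias = np.zeros(5, dtype=np.float32)
print(sparsity_percent([kernel, bias]))  # 22 zeros out of 25 entries
```

Applied to something like `model.get_weights()`, this counts biases as well. Whether biases are included changes the percentage, and the question does not state how the 91.3375% figure was computed.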

If you look at the example code, the relevant code snippets are:

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# 'quantize_model' requires recompilation-
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0012)

# Define 'train_one_step()' and 'test_step()' functions here-
@tf.function
def train_one_step(model, mask_model, optimizer, x, y):
    '''
    Function to compute one step of gradient descent optimization
    '''
    with tf.GradientTape() as tape:
        # Make predictions using defined model-
        y_pred = model(x)

        # Compute loss-
        loss = loss_fn(y, y_pred)

    # Compute gradients wrt defined loss and weights and biases-
    grads = tape.gradient(loss, model.trainable_variables)

    # type(grads)
    # list

    # List to hold element-wise multiplication between-
    # computed gradient and masks-
    grad_mask_mul = []

    # Perform element-wise multiplication between computed gradients and masks-
    for grad_layer, mask in zip(grads, mask_model.trainable_weights):
        grad_mask_mul.append(tf.math.multiply(grad_layer, mask))

    # Apply computed gradients to model's weights and biases-
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

    # Update loss and accuracy metrics-
    train_loss(loss)
    train_accuracy(y, y_pred)

    return None
def test_step(model, optimizer, data, labels):
    '''
    Function to test model performance
    on testing dataset
    '''
    predictions = model(data)
    t_loss = loss_fn(labels, predictions)

    # Update loss and accuracy metrics-
    test_loss(t_loss)
    test_accuracy(labels, predictions)

    return None

# Train model using 'GradientTape'-
# Initialize parameters for Early Stopping manual implementation-
best_val_loss = 100
loc_patience = 0
for epoch in range(num_epochs):
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for x, y in train_dataset:
        train_one_step(q_aware_model, mask_model, optimizer, x, y)

    for x_t, y_t in test_dataset:
        test_step(q_aware_model, optimizer, x_t, y_t)

    template = 'Epoch {0}, Loss: {1:.4f}, Accuracy: {2:.4f}, Test Loss: {3:.4f}, Test Accuracy: {4:.4f}'
    # 'i' is the index for number of pruning rounds-
    history_main[i]['accuracy'][epoch] = train_accuracy.result() * 100
    history_main[i]['loss'][epoch] = train_loss.result()
    history_main[i]['val_loss'][epoch] = test_loss.result()
    history_main[i]['val_accuracy'][epoch] = test_accuracy.result() * 100

    print(template.format(
        epoch + 1, train_loss.result(),
        train_accuracy.result()*100, test_loss.result(),
        test_accuracy.result()*100))

    # Count number of non-zero parameters in each layer and in total-
    # print("layer-wise manner model, number of nonzero parameters in each layer are: \n")
    model_sum_params = 0
    for layer in winning_ticket_model.trainable_weights:
        # print(tf.math.count_nonzero(layer, axis = None).numpy())
        model_sum_params += tf.math.count_nonzero(layer, axis = None).numpy()
    print("Total number of trainable parameters = {0}\n".format(model_sum_params))

    # Code for manual Early Stopping:
    # improvement only if the loss dropped by at least 'minimum_delta'-
    if (best_val_loss - test_loss.result()) >= minimum_delta:
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = test_loss.result()
        # reset 'loc_patience' variable-
        loc_patience = 0
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement
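The manual early-stopping rule can be isolated in a small pure-Python helper, which makes the `minimum_delta` comparison easy to check: an epoch counts as an improvement only if the validation loss falls at least `minimum_delta` below the best value seen so far. The class name `EarlyStopper` is hypothetical. Keras provides `tf.keras.callbacks.EarlyStopping` for `fit()`, but with a custom `GradientTape` loop the check has to be done by hand, roughly like this:

```python
class EarlyStopper:
    """Track validation loss and signal when training should stop."""

    def __init__(self, patience, minimum_delta, best_val_loss=float("inf")):
        self.patience = patience
        self.minimum_delta = minimum_delta
        self.best_val_loss = best_val_loss
        self.loc_patience = 0  # epochs since the last real improvement

    def update(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if self.best_val_loss - val_loss >= self.minimum_delta:
            self.best_val_loss = val_loss  # new best, reset the counter
            self.loc_patience = 0
        else:
            self.loc_patience += 1  # no sufficient improvement this epoch
        return self.loc_patience >= self.patience


stopper = EarlyStopper(patience=2, minimum_delta=0.01)
for epoch, val_loss in enumerate([0.50, 0.40, 0.395, 0.392, 0.30]):
    if stopper.update(val_loss):
        print("EarlyStopping called after epoch", epoch + 1)
        break
```

Here the drops from 0.40 to 0.395 and to 0.392 are both smaller than `minimum_delta`, so training stops after the fourth epoch and 0.30 is never seen.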


Running the training loop fails with the following error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
     19     for x, y in train_dataset:
     20 
---> 21         train_one_step(q_aware_model, mask_model, optimizer, x, y)
     22 
     23 

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    642         # Lifting succeeded, so variables are initialized and we can run the
    643         # stateless function.
--> 644         return self._stateless_fn(*args, **kwds)
    645     ...
    646       canon_args, canon_kwds = \

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2418     with self._lock:
   2419       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2421 
   2422   @property

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
   1659       `args` and `kwargs`.
   1660     """
-> 1661     return self._call_flat(
   1662         (t for t in nest.flatten((args, kwargs), expand_composites=True)
   1663 

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1743         and executing_eagerly):
   1744       # No tape is watching; skip to running the function.
-> 1745       return self._build_call_outputs(self._inference_function.call(
   1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    591       with _InterpolateFunctionError():
    592         ...
--> 593           outputs = execute.execute(
    594               ...
    595               num_outputs=self._num_outputs,

~/.local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError: var and grad do not have the same shape[10] [100,10]
     [[node Adam/Adam/update_4/ResourceApplyAdam (defined at :29) ]] [Op:__inference_train_one_step_20360]

Errors may have originated from an input operation.
Input Source operations connected to node Adam/Adam/update_4/ResourceApplyAdam:
 Mul_4 (defined at :26)
 sequential/quant_dense_2/BiasAdd/ReadVariableOp/resource (defined at /home/arjun/.local/lib/python3.8/site-packages/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py:162)
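The two shapes in the message can be explained without TensorFlow. `tf.math.multiply` broadcasts the same way NumPy does, so in `train_one_step` a gradient of shape [10] multiplied by a mask of shape [100, 10] does not fail at `Mul_4`. It produces a [100, 10] tensor, and the mismatch only surfaces when `ResourceApplyAdam` tries to apply that tensor to the [10] bias of `quant_dense_2`. In other words, `zip(grads, mask_model.trainable_weights)` appears to pair at least one gradient with the mask of a different variable. The NumPy sketch below reproduces the broadcast and adds a hypothetical helper, `first_shape_mismatch`, for finding the first position where two shape lists disagree. The swapped variable order in the demo is an assumption chosen to match the error message, not an order tfmot is known to produce.

```python
import numpy as np

# 1) Broadcasting: a [10] bias gradient times a [100, 10] kernel mask does
#    not raise; it silently yields a [100, 10] tensor.
bias_var = np.zeros(10, dtype=np.float32)           # variable Adam updates, shape [10]
bias_grad = np.ones(10, dtype=np.float32)           # its gradient, shape [10]
kernel_mask = np.ones((100, 10), dtype=np.float32)  # mask meant for a [100, 10] kernel
masked = bias_grad * kernel_mask
print(bias_var.shape, masked.shape)  # (10,) (100, 10)


# 2) Locate the first position where two shape lists stop lining up.
def first_shape_mismatch(var_shapes, mask_shapes):
    """Return (index, var_shape, mask_shape) for the first disagreement, or None.

    If one list is longer, the first index present in only one list is
    reported, with None standing in for the missing side.
    """
    for i, (v, m) in enumerate(zip(var_shapes, mask_shapes)):
        if tuple(v) != tuple(m):
            return i, tuple(v), tuple(m)
    if len(var_shapes) != len(mask_shapes):
        i = min(len(var_shapes), len(mask_shapes))
        v = tuple(var_shapes[i]) if i < len(var_shapes) else None
        m = tuple(mask_shapes[i]) if i < len(mask_shapes) else None
        return i, v, m
    return None


# LeNet-300-100 shapes in Keras Dense order (kernel, bias) for each layer.
mask_shapes = [(784, 300), (300,), (300, 100), (100,), (100, 10), (10,)]
# HYPOTHETICAL variable order with the last layer's bias and kernel swapped.
var_shapes = [(784, 300), (300,), (300, 100), (100,), (10,), (100, 10)]
print(first_shape_mismatch(var_shapes, mask_shapes))  # (4, (10,), (100, 10))
```

Against the real models, the lists to compare would be `[v.shape.as_list() for v in q_aware_model.trainable_variables]` and `[m.shape.as_list() for m in mask_model.trainable_weights]`. Printing `v.name` next to each entry shows which wrapped layer it belongs to.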

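A possible fix, sketched under assumptions: instead of pairing gradients with `mask_model.trainable_weights` by position, derive each mask from the very variable it will be applied to. Gradients, masks and variables then all come from the same list and cannot drift out of order. The sketch uses NumPy and plain SGD to show the idea. `make_masks` and `masked_sgd_step` are hypothetical helpers. Masking only arrays of rank 2 or more assumes that only kernels were pruned; an all-ones mask for biases avoids freezing a bias that merely happens to be zero.

```python
import numpy as np


def make_masks(variables):
    """Return one mask per variable, in the same order as `variables`.

    Arrays of rank >= 2 (kernels) get 1.0 where the weight is non-zero and
    0.0 where it was pruned; rank-1 arrays (biases) get an all-ones mask.
    """
    masks = []
    for v in variables:
        if v.ndim >= 2:
            masks.append((v != 0).astype(v.dtype))
        else:
            masks.append(np.ones_like(v))
    return masks


def masked_sgd_step(variables, grads, masks, lr=0.1):
    """Plain SGD step, in place, using each variable's own masked gradient."""
    for v, g, m in zip(variables, grads, masks):
        # Fail loudly instead of letting broadcasting hide a mismatch.
        assert v.shape == g.shape == m.shape, (v.shape, g.shape, m.shape)
        v -= lr * g * m


kernel = np.array([[0.0, 2.0], [3.0, 0.0]], dtype=np.float32)  # two pruned entries
bias = np.zeros(2, dtype=np.float32)
variables = [kernel, bias]
masks = make_masks(variables)

grads = [np.ones_like(kernel), np.ones_like(bias)]
masked_sgd_step(variables, grads, masks)
print(kernel)  # the two pruned entries are still exactly 0.0
print(bias)
```

In the TensorFlow loop, the corresponding (untested) change would be to build the masks once from the quantized model itself, for example `masks = [tf.cast(tf.not_equal(v, 0), v.dtype) if len(v.shape) >= 2 else tf.ones_like(v) for v in q_aware_model.trainable_variables]`, and then zip `grads` with `masks` instead of `mask_model.trainable_weights`. This assumes `quantize_model` copied the pruned weights into `q_aware_model`. With Adam, an entry whose gradient is zero at every step keeps zero first- and second-moment estimates, so its update is zero as well. That holds provided the optimizer is freshly created, as in the question, and no weight decay is applied.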



