我想在TensorFlow培训期间提出一个问题。 目前,我有一个网络可以进行超分辨率,我们知道SR模型包括特征提取和上采样。
现在,我想训练一个可以支持多尺度时间的模型。这在pytorch中很简单,但是在TensorFlow中,似乎很难实现。我实现如下: 例如 型号:
x = feature extract(input)
out = tf.case(pred_fn_pairs=[(tf.equal(scale, 2),
lambda: upsamplex2(x)),
(tf.equal(scale, 3), lambda: upsamplex3(x))],
default=lambda: upsamplex4(x), exclusive=False)
我将比例尺设置为占位符,但是在训练过程中,似乎只有x2模型可以很好地训练。用 x3和x4模型的推断结果非常糟糕。在这种情况下,有人知道如何进行培训吗? 非常感谢你!