Strange TensorFlow behavior

Time: 2017-06-08 14:09:44

Tags: python tensorflow keras

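I am implementing a custom weight regularizer. It involves computing the trace of a correlation matrix. The problem is in self.tr = trace_norm(self.cauchy_schwarz). When I cast to complex64 and then back to float32, the problem disappears. The relevant stack trace is not the most helpful one I have seen...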
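Read as math, the regularizer in the code appears to compute the following. This is an interpretation, not something the question states: W is self.We with shape (e_width*dim2) x modalities, C is self.cauchy_schwarz, Ψ is self.psi, λ is self.lr and δ is self.delta. It also assumes that symsqrt returns the principal square root and that WᵀW is invertible.

```latex
C = \left(W^\top W\right)^{1/2}, \qquad
\Psi = \frac{C}{\operatorname{tr}(C)}, \qquad
\delta = \lambda\,\operatorname{tr}\!\left(W \Psi^{-1} W^\top\right),

\operatorname{tr}\!\left(W \Psi^{-1} W^\top\right)
  = \operatorname{tr}\!\left(\Psi^{-1} W^\top W\right)
  = \operatorname{tr}(C)\,\operatorname{tr}\!\left(C^{-1} C^{2}\right)
  = \operatorname{tr}(C)^{2}.
```

The second line uses Ψ⁻¹ = tr(C)·C⁻¹ and C² = WᵀW. Under those assumptions the penalty reduces to λ·tr(C)², where tr(C) is the trace (nuclear) norm of W, so the matrix inverse is not strictly needed.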

Edit: in the code below, trace_norm(w) becomes tf.trace(w) when running on the GPU.
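Suppose trace_norm is meant as the trace (nuclear) norm, the sum of singular values. In that case, swapping in tf.trace, which sums the diagonal, is only exact for symmetric positive semidefinite inputs, such as the principal square root of W'W. For an indefinite symmetric matrix the two values differ. Below is a minimal pure-Python sketch on 2 x 2 matrices that illustrates this. The helper names eig_sym2, trace2 and nuclear_norm_sym2 are hypothetical and are not TensorFlow functions.

```python
import math

# Assumption: trace_norm is meant as the nuclear norm (sum of singular values).
# For a symmetric matrix the singular values are the absolute eigenvalues.

def eig_sym2(a, b, d):
    """Eigenvalues of the symmetric 2x2 matrix [[a, b], [b, d]] (closed form)."""
    mean = (a + d) / 2.0
    radius = math.sqrt(((a - d) / 2.0) ** 2 + b ** 2)
    return mean - radius, mean + radius

def trace2(a, b, d):
    """Ordinary trace: the sum of the diagonal, as tf.trace computes it."""
    return a + d

def nuclear_norm_sym2(a, b, d):
    """Nuclear norm of a symmetric 2x2 matrix: sum of |eigenvalues|."""
    return sum(abs(lam) for lam in eig_sym2(a, b, d))

# Positive semidefinite: eigenvalues 1 and 3, trace and nuclear norm agree.
psd = (2.0, 1.0, 2.0)
# Indefinite: eigenvalues -1 and 3, trace 2 but nuclear norm 4.
indefinite = (1.0, 2.0, 1.0)

for name, m in (("psd", psd), ("indefinite", indefinite)):
    print(name, trace2(*m), nuclear_norm_sym2(*m))
# prints:
# psd 4.0 4.0
# indefinite 2.0 4.0
```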

Stack trace:

E c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\stream_executor\cuda\cuda_event.cc:49] Error polling for event status: failed to query event: CUDA_ERROR_LAUNCH_FAILED
F c:\tf_jenkins\home\workspace\release-win\device\gpu\os\windows\tensorflow\core\common_runtime\gpu\gpu_event_mgr.cc:198] Unexpected Event status: 1
E c:\tf_je

Code:

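from keras import backend as K
from keras.regularizers import Regularizer
from tensorflow import cast, reshape, matmul, matrix_inverse
# symsqrt, trace_norm and multiply_elemwise are helpers defined elsewhere (not shown)

class TracenormRegularizer(Regularizer):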

    def __init__(self, lr, e_width=1000, modalities=6):
        self.lr = cast(lr,dtype='float32')
        self.modalities = cast(modalities,dtype='int32')
        self.e_width = cast(e_width,dtype='int32')

        self.uses_learning_phase = True

    def set_param(self, p):
        if hasattr(self, 'p'):
            raise Exception('Regularizers cannot be reused. '
                            'Instantiate one regularizer per layer.')
        self.p = p

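    def __call__(self, loss):
        # check the raw weights: after the reshape below the tensor is always 2-D
        if K.ndim(self.p) > 2:
            raise Exception('Tracenorm regularizer '
                            'is only available for dense '
                            'and embedding layers.')
        _, dim2 = K.eval(K.shape(self.p))
        # W has shape (e_width * dim2) x modalities
        self.We = K.transpose(reshape(self.p, [self.modalities, self.e_width * dim2]))
        # W'W is modalities x modalities, so its symmetric square root is well defined
        self.cauchy_schwarz = symsqrt(matmul(self.We, self.We, transpose_a=True))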
        self.tr = cast(trace_norm(cast(self.cauchy_schwarz,dtype='complex64')),dtype='float32')
        self.psi = multiply_elemwise(1/self.tr,self.cauchy_schwarz)
        self.psi_inv = matrix_inverse(self.psi)
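        # tr(A'B) = tr(AB') with A = W psi^-1, B = W: a modalities x modalities
        # product instead of an (e_width*dim2) x (e_width*dim2) one, to save memory
        self.correlations = matmul(matmul(self.We, self.psi_inv), self.We, transpose_a=True)
        self.delta = self.lr * trace_norm(self.correlations)  # the issue repeats here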
        regularized_loss = loss + self.delta
        return K.in_train_phase(regularized_loss, loss)
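The shapes and the trace identity from the "save memory" comment can be checked independently of the GPU. Here is a small pure-Python sketch that does this under stated assumptions. W is an N x m matrix of made-up values with N = 3 and m = 2. symsqrt is taken to be the principal matrix square root, computed here with the closed form for 2 x 2 symmetric positive definite matrices. All helper names are hypothetical. The sketch computes tr(W psi^-1 W') three ways: with the N x N product, with the m x m product, and as tr(C)^2, which follows from psi^-1 = tr(C) C^-1.

```python
import math

# Assumption: symsqrt is the principal matrix square root; W below is made-up data.

def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def trace(a):
    return sum(a[i][i] for i in range(len(a)))

def sqrtm_spd2(g):
    """Principal square root of a 2x2 SPD matrix: (G + sI) / t, s = sqrt(det G)."""
    s = math.sqrt(g[0][0] * g[1][1] - g[0][1] * g[1][0])
    t = math.sqrt(trace(g) + 2.0 * s)
    return [[(g[0][0] + s) / t, g[0][1] / t],
            [g[1][0] / t, (g[1][1] + s) / t]]

def inv2(a):
    """Inverse of a 2x2 matrix."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [[a[1][1] / det, -a[0][1] / det],
            [-a[1][0] / det, a[0][0] / det]]

def regularizer_terms(w):
    """Return tr(W psi^-1 W') computed three ways for an N x 2 matrix W."""
    g = matmul(transpose(w), w)                    # W'W, 2 x 2
    c = sqrtm_spd2(g)                              # C = (W'W)^(1/2)
    tr_c = trace(c)
    psi = [[x / tr_c for x in row] for row in c]   # psi = C / tr(C)
    wp = matmul(w, inv2(psi))                      # W psi^-1, N x 2
    full = trace(matmul(wp, transpose(w)))         # N x N product
    saved = trace(matmul(transpose(wp), w))        # 2 x 2 product, same trace
    closed = tr_c ** 2                             # tr(C)^2, no inverse needed
    return full, saved, closed

w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
full, saved, closed = regularizer_terms(w)
print(full, saved, closed)  # all approximately 4 + 2*sqrt(3)
```

For this W, W'W = [[2, 1], [1, 2]] has eigenvalues 1 and 3. That gives tr(C) = 1 + sqrt(3), and all three forms agree up to rounding.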

0 answers

No answers.