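I am trying to implement a Diagonal BiLSTM in PyTorch. I have set up the CBAM model, including the spatial gate described in the paper. The code that defines my conv2d layer is below.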
def Conv2D(name, input_dim, output_dim, filter_size, inputs, mask_type=None, he_init=False):
    """
    inputs.shape: (batch size, height, width, input_dim)
    mask_type: None, 'a', 'b'
    output.shape: (batch size, height, width, output_dim)
    """
    def uniform(stdev, size):
        return np.random.uniform(
            low=-stdev * np.sqrt(3),
            high=stdev * np.sqrt(3),
            size=size
        ).astype(numpy.float32)
    filters_init = uniform(
        1./np.sqrt(input_dim * filter_size * filter_size),
        # output dim, input dim, height, width
        (output_dim, input_dim, filter_size, filter_size)
    )
    if he_init:
        filters_init *= lib.floatX(np.sqrt(2.))
    if mask_type is not None:
        filters_init *= lib.floatX(np.sqrt(2.))
    filters = lib.param(
        name+'.Filters',
        filters_init
    )
    if mask_type is not None:
        mask = np.ones(
            (output_dim, input_dim, filter_size, filter_size),
            dtype=np.float32
        )
        center = filter_size//2
        for i in range(filter_size):
            for j in range(filter_size):
                if (j > center) or (j==center and i > center):
                    mask[:, :, j, i] = 0.
        for i in range(N_CHANNELS):
            for j in range(N_CHANNELS):
                if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j):
                    mask[
                        j::N_CHANNELS,
                        i::N_CHANNELS,
                        center,
                        center
                    ] = 0.
        filters = filters * mask
    inputs = inputs.permute(0, 3, 1, 2)
    #print(inputs)
    inps = torch.cat((torch.max(inputs, 1)[0].unsqueeze(1), torch.mean(inputs, 1).unsqueeze(1)), dim=1)
    print(inps)
    result = torch.nn.Conv2d(inps, 1, 7, stride=1)
    biases = lib.param(
        name+'.Biases',
        np.zeros(output_dim, dtype=np.float32)
    )
    result = result + biases[None, :, None, None]
    return result.permute(0, 2, 3, 1)
But when I pass the tensor to torch.nn.Conv2d, I get the error shown below.
I printed the tensor stored in the inps variable, and it looks like this:
tensor([[[[ 2.1475e+00, 2.2656e+00, 2.2285e+00, ..., 2.1634e+00,
1.9802e+00, 2.0768e+00],
[ 1.1065e-01, 9.3942e-02, 1.3884e-01, ..., 6.6712e-02,
9.9830e-02, 1.4429e-01]],
[[ 1.6910e+00, 1.5110e+00, 1.5579e+00, ..., 1.0768e+00,
1.5984e+00, 1.6736e+00],
[-8.2758e-02, -1.3184e-02, 5.0098e-02, ..., 2.2589e-03,
5.8106e-02, 7.3571e-03]]],
[[[ 2.2599e+00, 2.3655e+00, 2.2511e+00, ..., 2.5398e+00,
2.0128e+00, 1.9834e+00],
[ 9.4436e-02, 9.2293e-02, 1.5296e-01, ..., 5.4408e-02,
9.8131e-02, 1.3627e-01]],
[[ 1.6812e+00, 1.4765e+00, 1.5793e+00, ..., 1.1819e+00,
1.6332e+00, 1.6482e+00],
[-8.4564e-02, -1.1611e-02, 5.6016e-02, ..., 1.8283e-03,
5.0835e-02, 1.7487e-02]]],
[[[ 2.2437e+00, 2.4011e+00, 2.4019e+00, ..., 2.4042e+00,
2.0544e+00, 2.2374e+00],
[ 8.5540e-02, 9.6436e-02, 1.3502e-01, ..., 7.4034e-02,
1.0697e-01, 1.4066e-01]],
[[ 1.6898e+00, 1.4664e+00, 1.5747e+00, ..., 1.0820e+00,
1.6203e+00, 1.7650e+00],
[-7.5721e-02, -1.1245e-02, 5.4568e-02, ..., -5.5246e-03,
5.6962e-02, 9.4589e-03]]],
...,
[[[ 2.1059e+00, 2.3265e+00, 2.4311e+00, ..., 2.4247e+00,
1.9939e+00, 2.0536e+00],
[ 9.6550e-02, 1.0173e-01, 1.4329e-01, ..., 6.5086e-02,
1.0176e-01, 1.3881e-01]],
[[ 1.6375e+00, 1.5037e+00, 1.5442e+00, ..., 1.1553e+00,
1.6443e+00, 1.6747e+00],
[-8.6096e-02, -1.2071e-02, 5.6651e-02, ..., -1.4265e-03,
5.6373e-02, 1.3945e-02]]],
[[[ 2.3196e+00, 2.1638e+00, 2.1018e+00, ..., 2.3779e+00,
1.9471e+00, 1.9064e+00],
[ 9.8306e-02, 9.1398e-02, 1.5183e-01, ..., 6.4398e-02,
1.0629e-01, 1.4232e-01]],
[[ 1.6620e+00, 1.5137e+00, 1.5810e+00, ..., 1.0661e+00,
1.5326e+00, 1.6870e+00],
[-8.2823e-02, -8.9255e-03, 5.7103e-02, ..., -3.1661e-03,
5.9011e-02, 6.9708e-03]]],
[[[ 2.2586e+00, 2.3127e+00, 2.1126e+00, ..., 2.3692e+00,
2.0004e+00, 2.0361e+00],
[ 8.9194e-02, 9.7919e-02, 1.4491e-01, ..., 8.3260e-02,
1.0935e-01, 1.4324e-01]],
[[ 1.7048e+00, 1.5188e+00, 1.6082e+00, ..., 1.1044e+00,
1.5624e+00, 1.7280e+00],
[-8.0997e-02, -8.0655e-03, 6.0712e-02, ..., -4.3486e-03,
5.7900e-02, 8.8437e-03]]]], grad_fn=<CatBackward>)
RuntimeError: bool value of Tensor with more than one value is ambiguous
Please help me. I am new to PyTorch and would appreciate any guidance.
Answer 0 (score: 0)
You need to declare the convolution layer first, and then pass the data through that layer.
myconv = torch.nn.Conv2d(in_channels=2, out_channels=7, kernel_size=1, stride=1)
# NOTE: Switched in_channels 1->2, and added a kernel size argument ... please change this as needed for your purposes!
result = myconv(inps)
However, declare myconv somewhere else, because re-declaring it on every call will not work for training: it resets the weights each time. I am not sure how your network is structured, but the layer needs to be initialized elsewhere and then used here.
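One way to keep the layer's weights alive between calls is to wrap the pooling and the convolution in an nn.Module, so the layer is created once in __init__ and only applied in forward. The sketch below is an illustration, not the asker's actual network: the class name SpatialGate is hypothetical, out_channels=1 and kernel_size=7 are taken from the positional arguments in the question's call, and the padding is an added assumption so the output keeps the input's height and width.

```python
import torch
import torch.nn as nn


class SpatialGate(nn.Module):
    """Hypothetical CBAM-style spatial gate; names and sizes are assumptions."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # Created once here, so the weights persist across forward passes
        # and are registered as parameters the optimizer can update.
        self.conv = nn.Conv2d(
            in_channels=2,              # max-pooled map + mean-pooled map
            out_channels=1,             # assumption, from the question's call
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,   # assumption: keep spatial size
        )

    def forward(self, inputs):
        # inputs: (batch, channels, height, width)
        pooled = torch.cat(
            (torch.max(inputs, 1)[0].unsqueeze(1),
             torch.mean(inputs, 1).unsqueeze(1)),
            dim=1,
        )
        # The layer is called on the tensor; it is not constructed from it.
        return self.conv(pooled)


gate = SpatialGate()
x = torch.randn(4, 8, 16, 16)
out = gate(x)
print(out.shape)  # torch.Size([4, 1, 16, 16])
```

Because gate is built once, calling gate(x) repeatedly reuses the same self.conv.weight, and gate.parameters() can be handed to an optimizer. In the question's function, the equivalent would be to build the module outside Conv2D and pass it in or look it up by name.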