如何将形状较小的张量散射到较大的张量?

时间:2019-03-06 00:14:55

标签: python tensorflow keras keras-layer

我有一个形状为 (1, 4, 4, 1) 的输入张量，想把它与某个卷积层形状为 (1, 28, 28, 1) 的输出拼接起来。为此，我需要先把输入的每个值放入一个形状为 (1, 28, 28, 1) 的新张量中，使每个值恰好位于对应 7×7 块的正中心。如果是 NumPy 数组，我可以这样做：`w_expand[:, 3::7, 3::7] = wt_random`。但在 Keras 中处理的是张量，需要借助 Lambda 层，我不知道该如何实现。您能帮我解决这个问题吗？

from keras.layers import Input, Concatenate, GaussianNoise,Dropout,BatchNormalization,MaxPool2D,AveragePooling2D
from keras.layers import Conv2D, AtrousConv2D
from keras.models import Model
from keras.datasets import mnist
from keras.callbacks import TensorBoard
from keras import backend as K
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as Kr
from keras.optimizers import SGD,RMSprop,Adam
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import numpy as np
import pylab as pl
import matplotlib.cm as cm
import keract
from matplotlib import pyplot
from keras import optimizers
from keras import regularizers

from tensorflow.python.keras.layers import Lambda;
#----------------- random binary "w" matrices: train / validation ------------
# Every sample is a 4x4 grid of 0/1 values.  The random draws happen in the
# order train -> validation -> test so the global NumPy RNG stream is consumed
# exactly as before.
wt_random = np.random.randint(2, size=(49999, 4, 4))
wv_random = np.random.randint(2, size=(9999, 4, 4))

# float32 copies with a trailing channel axis, matching Input((4, 4, 1)).
w_expand = wt_random.astype(np.float32)[..., np.newaxis]
wv_expand = wv_random.astype(np.float32)[..., np.newaxis]

# (samples, rows, cols) of the validation set, kept as module globals.
x, y, z = wv_random.shape

#----------------- single random "w" matrix for testing ----------------------
w_test = np.random.randint(2, size=(1, 4, 4)).astype(np.float32)[..., np.newaxis]


#-----------------------encoder------------------------------------------------
#------------------------------------------------------------------------------
wtm = Input((4, 4, 1))
image = Input((28, 28, 1))
conv1 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl1e', dilation_rate=(2, 2))(image)
conv2 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl2e', dilation_rate=(2, 2))(conv1)
conv3 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl3e', dilation_rate=(2, 2))(conv2)
BN = BatchNormalization()(conv3)
encoded = Conv2D(1, (5, 5), activation='relu', padding='same', name='encoded_I')(BN)

#-----------------------adding w---------------------------------------
# Side length of the zero tile each w value is placed into.  4 * 7 == 28
# matches the encoder's spatial size, so both maps can be concatenated.
W_BLOCK = 7


def scatter_to_block_centers(t, block=W_BLOCK):
    """Place every spatial value of ``t`` at the centre of a ``block`` x ``block`` tile of zeros.

    Tensor equivalent of the NumPy assignment
    ``out[:, c::block, c::block] = t`` with ``c = block // 2``; for the
    default ``block=7`` that is ``out[:, 3::7, 3::7] = t``.

    Args:
        t: 4-D tensor (batch, h, w, ch).  Assumes h, w and ch are statically
            known (true for ``Input((4, 4, 1))``).
        block: tile size; output spatial size is (h * block, w * block).

    Returns:
        Tensor of shape (batch, h * block, w * block, ch).
    """
    _, h, w, ch = K.int_shape(t)
    before = block // 2
    after = block - 1 - before
    # Give each value its own size-1 "row tile" and "column tile" axis ...
    t = tf.reshape(t, [-1, h, 1, w, 1, ch])
    # ... grow those tiles to block x block with zeros around the value ...
    t = tf.pad(t, [[0, 0], [0, 0], [before, after], [0, 0], [before, after], [0, 0]])
    # ... and merge (h, block) / (w, block): value (i, j) lands at
    # row i * block + before, column j * block + before.
    return tf.reshape(t, [-1, h * block, w * block, ch])


# keras.layers.Lambda (through the already-imported ``layers`` module) is used
# on purpose: the tensorflow.python.keras Lambda imported at the top of the
# file belongs to a different Keras implementation and cannot be mixed with
# these standalone-Keras layers in one model.
wtmp = layers.Lambda(
    scatter_to_block_centers,
    output_shape=lambda s: (s[0], s[1] * W_BLOCK, s[2] * W_BLOCK, s[3]),
    name='w_scatter',
)(wtm)
encoded_merged = Concatenate(axis=3)([encoded, wtmp])  # (28, 28, 2)

#-----------------------decoder------------------------------------------------
#------------------------------------------------------------------------------
deconv1 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl1d', dilation_rate=(2, 2))(encoded_merged)
deconv2 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl2d', dilation_rate=(2, 2))(deconv1)
deconv3 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl3d', dilation_rate=(2, 2))(deconv2)
deconv4 = Conv2D(64, (5, 5), activation='relu', padding='same', name='convl4d', dilation_rate=(2, 2))(deconv3)
BNd = BatchNormalization()(deconv4)

decoded = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='decoder_output')(BNd)

model = Model(inputs=[image, wtm], outputs=decoded)

decoded_noise = GaussianNoise(0.5)(decoded)

#----------------------w extraction------------------------------------
# Unpadded 5x5 convolutions shrink the map by 4 per layer: 28 -> 4.
convw1 = Conv2D(64, (5, 5), activation='relu', name='conl1w')(decoded_noise)  # 24
convw2 = Conv2D(64, (5, 5), activation='relu', name='convl2w')(convw1)  # 20
convw3 = Conv2D(64, (5, 5), activation='relu', name='conl3w')(convw2)  # 16
convw4 = Conv2D(64, (5, 5), activation='relu', name='conl4w')(convw3)  # 12
convw5 = Conv2D(64, (5, 5), activation='relu', name='conl5w')(convw4)  # 8
convw6 = Conv2D(64, (5, 5), activation='relu', name='conl6w')(convw5)  # 4
BNed = BatchNormalization()(convw6)
# NOTE: dilation has no effect on a 1x1 kernel; kept only to leave the layer config unchanged.
pred_w = Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='reconstructed_W', dilation_rate=(2, 2))(BNed)
w = Model(inputs=[image, wtm], outputs=[decoded, pred_w])

w.summary()

0 个答案:

没有答案