def forward_propagation(X, n_classes=51):
    """Build the forward pass of a small 3D CNN for video classification.

    Architecture (TensorFlow 1.x graph mode): three stages of
    conv3d -> +bias -> ReLU -> max_pool3d, then flatten and three fully
    connected layers (512 -> 128 -> n_classes).

    Arguments:
    X -- 5-D input tensor (batch, frames, height, width, channels).
         W0's shape requires channels == 3. The shape comments below
         assume the placeholder shape built by the calling script,
         (N, 20, 256, 256, 3); other input sizes change every downstream
         shape, including the flattened width fed to the first dense layer.
    n_classes -- number of units in the final layer. Defaults to 51, the
         value this function previously hard-coded. Pass the label width
         (e.g. n_y) so the logits always match the labels.

    Returns:
    Y -- tensor of shape (N, n_classes). The last layer uses
         activation_fn=None, so these are raw logits. The loss is expected
         to apply the softmax itself (TODO confirm in compute_cost).

    Note: W0-W2 are created with tf.get_variable, so calling this function
    twice in the same graph without variable reuse raises a ValueError.
    """
    # Conv kernels are laid out as
    # [depth, height, width, in_channels, out_channels].
    W0 = tf.get_variable("W0", shape=[2, 2, 2, 3, 16],
                         initializer=tf.contrib.layers.xavier_initializer(seed=0))
    W1 = tf.get_variable("W1", [4, 4, 4, 16, 32],
                         initializer=tf.contrib.layers.xavier_initializer(seed=0))
    W2 = tf.get_variable("W2", [4, 4, 4, 32, 64],
                         initializer=tf.contrib.layers.xavier_initializer(seed=0))
    # One bias per output channel of the matching conv layer.
    b0 = tf.Variable(tf.zeros([16]))
    b1 = tf.Variable(tf.zeros([32]))
    b2 = tf.Variable(tf.zeros([64]))

    # With padding="SAME" each spatial/temporal output size is
    # ceil(input_size / stride), independent of the kernel size.

    # Stage 1: (N, 20, 256, 256, 3) -> conv 2x2x2 -> (N, 20, 256, 256, 16)
    con1 = tf.nn.conv3d(X, W0, strides=[1, 1, 1, 1, 1], padding="SAME")
    A1 = tf.nn.relu(con1 + b0)
    # pool k=(4,8,8), s=(2,8,8) -> (N, 10, 32, 32, 16)
    pol1 = tf.nn.max_pool3d(A1, ksize=[1, 4, 8, 8, 1],
                            strides=[1, 2, 8, 8, 1], padding="SAME")

    # Stage 2: conv 4x4x4 -> (N, 10, 32, 32, 32)
    con2 = tf.nn.conv3d(pol1, W1, strides=[1, 1, 1, 1, 1], padding="SAME")
    A2 = tf.nn.relu(con2 + b1)
    # pool k=s=(4,4,4) -> (N, ceil(10/4)=3, 8, 8, 32)
    pol2 = tf.nn.max_pool3d(A2, ksize=[1, 4, 4, 4, 1],
                            strides=[1, 4, 4, 4, 1], padding="SAME")

    # Stage 3: conv 4x4x4 -> (N, 3, 8, 8, 64)
    con3 = tf.nn.conv3d(pol2, W2, strides=[1, 1, 1, 1, 1], padding="SAME")
    A3 = tf.nn.relu(con3 + b2)
    # pool k=s=(2,4,4) -> (N, ceil(3/2)=2, 2, 2, 64)
    pol3 = tf.nn.max_pool3d(A3, ksize=[1, 2, 4, 4, 1],
                            strides=[1, 2, 4, 4, 1], padding="SAME")

    # Flatten (N, 2, 2, 2, 64) -> (N, 512).
    flat = tf.contrib.layers.flatten(pol3)
    # Hidden dense layers use fully_connected's default activation (ReLU).
    Y = tf.contrib.layers.fully_connected(flat, 512)
    Y = tf.contrib.layers.fully_connected(Y, 128)
    # Output layer: no activation, i.e. raw logits.
    Y = tf.contrib.layers.fully_connected(Y, n_classes, activation_fn=None)
    return Y
# Build the full training/evaluation graph from scratch.
ops.reset_default_graph()
seed = 3  # not used in this section; presumably consumed later
# Hard-coded stand-ins for X_train.shape, presumably
# (examples, frames, height, width, channels).
(m, n_f, n_H0, n_W0, n_C0) = (2000, 20, 256, 256, 3)  # X_train.shape
n_y = 51  # y_train.shape[1]: width of each label vector
costs_train = []
costs_val = []

X, Y = create_placeholders(n_f, n_H0, n_W0, n_C0, n_y)
print(X)
print(Y)

Y_pred = forward_propagation(X)
cost = compute_cost(Y=Y, Z=Y_pred)
print(cost)

# ApplyRMSProp requires a rank-0 (scalar) learning rate. A shape-[1] value,
# such as a placeholder declared with shape [1] or a one-element array,
# fails with "Shape must be rank 0 but is rank 1 for
# 'RMSProp/update_W0/ApplyRMSProp'" (the 4th input, lr, has shape [1]).
# Reshaping to [] accepts a Python float, a scalar tensor, or a
# one-element tensor. It still fails loudly if more than one value is given.
learning_rate_scalar = tf.reshape(learning_rate, [])
optimizer = tf.train.RMSPropOptimizer(
    learning_rate=learning_rate_scalar).minimize(cost)

# Evaluation ops: predicted class = argmax of the logits. This compares
# against argmax of Y, which assumes one-hot labels (TODO confirm).
predict_op = tf.argmax(Y_pred, 1)
correct_prediction = tf.equal(predict_op, tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
init = tf.global_variables_initializer()
我正在尝试为视频分类构建一个 3D CNN（CNN3D）模型，但遇到了下面这个错误，无法弄清楚它的来源：
ValueError: Shape must be rank 0 but is rank 1 for 'RMSProp/update_W0/ApplyRMSProp' (op: 'ApplyRMSProp') with input shapes: [2,2,2,3,16], [2,2,2,3,16], [2,2,2,3,16], [1], [], [], [], [2,2,2,3,16].
我该如何定位并解决这个问题？