理解通过数组索引获得的矩阵

时间:2018-02-24 05:28:58

标签: python numpy

logistic regression code中列出的代码中,我看到了以下代码段。让我失望的是表达: probs[range(num_examples),y]。 有人能告诉我这个矩阵有什么尺寸吗?我的猜测是它是N*K N*K矩阵,但我不确定。感谢。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K,D))
y = np.zeros(N*K, dtype='uint8')
for j in xrange(K):
  ix = range(N*j,N*(j+1))
  r = np.linspace(0.0,1,N) # radius
  t = np.linspace(j*4,(j+1)*4,N) + np.random.randn(N)*0.2 # theta
  X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
  y[ix] = j

#Train a Linear Classifier

# initialize parameters randomly
W = 0.01 * np.random.randn(D,K)
b = np.zeros((1,K))

# some hyperparameters
step_size = 1e-0
reg = 1e-3 # regularization strength

# gradient descent loop
num_examples = X.shape[0]
for i in xrange(200):

  # evaluate class scores, [N x K]
  scores = np.dot(X, W) + b 

  # compute the class probabilities
  exp_scores = np.exp(scores)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True) # [N x K]

  # compute the loss: average cross-entropy loss and regularization
  corect_logprobs = -np.log(probs[range(num_examples),y])
  data_loss = np.sum(corect_logprobs)/num_examples
  reg_loss = 0.5*reg*np.sum(W*W)
  loss = data_loss + reg_loss
  if i % 10 == 0:

1 个答案:

答案 0 :(得分:3)

$col似乎是一维切片,其中:

  • probs[range(num_examples), y]是一个跨越样本长度的向量
  • y是1D向量,长度N * K