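While working through the assignments for Stanford's CS231n course, I ran into many things I did not understand. Only after practicing on other code did I gradually become familiar with numpy and then come to understand how the assignments use it. There are many solutions to these assignments online, but very few explain them in detail. This article uses simple examples to explain the numpy usage in the solutions and to make the logic behind them clear, so that readers do not end up as frustrated as I was when I knew nothing.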
1. KNN
In this part, we use CIFAR-10 as our training data. To make the code run faster in this exercise, we use only 5000 images for training and 500 for testing. Since each picture is 32x32 pixels with 3 channels, the shape of X_train is (5000, 32, 32, 3), and the shape of X_test is (500, 32, 32, 3).
Then each image is flattened into a row vector, so the shape of X_train becomes (5000, 3072) and the shape of X_test becomes (500, 3072), where 3072 = 32*32*3.
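The flattening step described above can be sketched as follows. This is a minimal illustration on random stand-in data (5 and 2 images instead of 5000 and 500); the names X_train_flat and X_test_flat are chosen here for illustration and do not come from the assignment.

```python
import numpy as np

# Stand-in data: 5 "training images" and 2 "test images" of shape 32x32x3.
# (The real exercise uses 5000 and 500 CIFAR-10 images.)
rng = np.random.default_rng(0)
X_train = rng.random((5, 32, 32, 3))
X_test = rng.random((2, 32, 32, 3))

# Keep the first axis (number of images); -1 lets numpy infer 32*32*3 = 3072.
X_train_flat = X_train.reshape(X_train.shape[0], -1)
X_test_flat = X_test.reshape(X_test.shape[0], -1)

print(X_train_flat.shape)  # (5, 3072)
print(X_test_flat.shape)   # (2, 3072)
```

Each row of the flattened array is one image, so row i of X_train_flat holds the same numbers as X_train[i], just laid out in one dimension.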
(1) Function compute_distances_two_loops:
The logic in this part is quite easy. It reads like plain Python loops, with no special numpy techniques.
import numpy as np

def compute_distances_two_loops(X_train, X):
    num_test = X.shape[0]
    num_train = X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    for i in range(num_test):
        for j in range(num_train):
            # Euclidean (L2) distance between test row i and training row j
            dists[i,j] = np.sqrt(np.sum((X_train[j,:] - X[i,:])**2))
    return dists
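The key line is dists[i,j] = np.sqrt(np.sum((X_train[j,:] - X[i,:])**2)). It uses two numpy features:
- A[i,:] - B[j,:] (selecting a row or a column with a slice)
>> A = np.array([[1,2,3,4,5,6], [9,8,7,6,5,4]])
>> B = np.array([[11,12,13,14,15,16], [19,18,17,16,15,14]])
>> A[0,:]
array([1, 2, 3, 4, 5, 6])
>> B[1,:]
array([19, 18, 17, 16, 15, 14])
>> A[:,3]
array([4, 6])
>> B[:,2]
array([13, 17])
- np.sum()
>> A = np.array([1,2,3])
>> A**2
array([1, 4, 9])
>> print(np.sum(A**2))  # 1+4+9
14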
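To see the pieces of that distance expression work together, here is a minimal sketch on two made-up points (x_test and x_train are illustrative names, not assignment data). It also compares the result with np.linalg.norm, which computes the same Euclidean length.

```python
import numpy as np

# Two tiny made-up points, chosen so the distance is easy to check by hand.
x_test = np.array([0.0, 0.0, 0.0])
x_train = np.array([1.0, 2.0, 2.0])

diff = x_train - x_test          # element-wise difference: [1. 2. 2.]
squared = diff ** 2              # element-wise square:     [1. 4. 4.]
dist = np.sqrt(np.sum(squared))  # sqrt(1 + 4 + 4) = 3.0

print(dist)                      # 3.0
# np.linalg.norm computes the same Euclidean (L2) length.
print(np.isclose(dist, np.linalg.norm(diff)))  # True
```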
(2) Function compute_distances_one_loop:
This function is a little harder, because one of the two loops is replaced by a numpy operation on a whole matrix.
def compute_distances_one_loop(X_train, X):
    num_test = X.shape[0]
    num_train = X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    for i in range(num_test):
        # Distance from test row i to every training row at once
        dists[i,:] = np.sqrt(np.sum(np.square(X_train - X[i,:]), axis = 1))
    return dists
The logic in this part is similar to compute_distances_two_loops, but the numpy usage is a little harder.
- X_train - X[i,:] (broadcasting: the 1-D row is subtracted from every row of the matrix)
>> a = np.arange(15).reshape(3,5)
>> print(a)
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]
>> b = np.arange(5)
>> print(b)
[0 1 2 3 4]
>> a-b
array([[ 0,  0,  0,  0,  0],
       [ 5,  5,  5,  5,  5],
       [10, 10, 10, 10, 10]])
- np.square()
>> a = np.arange(15).reshape(3,5)
>> np.square(a) == a**2
array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])
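The one-loop function also passes axis = 1 to np.sum, which the examples above do not cover. A minimal sketch of what that argument does, reusing the same small a and b (row_sums and col_sums are names chosen here for illustration):

```python
import numpy as np

a = np.arange(15).reshape(3, 5)   # 3 "training rows" with 5 features each
b = np.arange(5)                  # 1 "test row" with 5 features

sq = np.square(a - b)             # b is broadcast across every row of a
print(sq.shape)                   # (3, 5)

row_sums = np.sum(sq, axis=1)     # sum within each row -> one value per row
print(row_sums)                   # [  0 125 500]

col_sums = np.sum(sq, axis=0)     # sum down each column -> one value per column
print(col_sums)                   # [125 125 125 125 125]

print(np.sqrt(row_sums))          # distance from b to each row of a
```

Summing along axis 1 collapses the feature dimension, which is why the result has one entry per training row and can be assigned directly to dists[i,:].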
(3) Function compute_distances_no_loops:
In this part, distances are computed using only matrix operations. The logic follows from expanding the squared distance formula. Here I only show the steps and results of those matrix operations.
def compute_distances_no_loops(X_train, X):
    num_test = X.shape[0]    # 500
    num_train = X_train.shape[0]    # 5000
    dists = np.zeros((num_test, num_train))    # 500x5000 matrix
    test_sum = np.sum(np.square(X), axis=1)    # squared norm of each test row
    train_sum = np.sum(np.square(X_train), axis=1)    # squared norm of each training row
    inner_product = np.dot(X, X_train.T)    # shape (num_test, num_train)
    # Reshape test_sum into a ?x1 column vector; '-1' lets numpy work out the ? automatically
    dists = np.sqrt(-2*inner_product + test_sum.reshape(-1,1) + train_sum)
    return dists
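The expansion mentioned above can be written out as follows. The symbols x_i (test row i), y_j (training row j) and k (feature index) are notation introduced here for illustration, not taken from the assignment.

```latex
\begin{aligned}
d_{ij}^2 &= \sum_k (x_{ik} - y_{jk})^2 \\
         &= \sum_k x_{ik}^2 \;-\; 2\sum_k x_{ik}\, y_{jk} \;+\; \sum_k y_{jk}^2
\end{aligned}
```

The first term corresponds to test_sum[i], the middle sum to inner_product[i,j], and the last term to train_sum[j]. Adding the (num_test, 1) column to the (num_train,) row broadcasts them into a (num_test, num_train) matrix, so every (i, j) pair is filled in at once.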
- A concrete example:
# Use a and b to represent the training dataset and the test dataset
>> a = np.arange(15).reshape(3,5)
>> a
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14]])
>> b = np.array([[11,12,13,14,15],[21,22,23,24,25]])
>> b
array([[11, 12, 13, 14, 15],
       [21, 22, 23, 24, 25]])
- With one loop:
# a is the training set, b is the test set
num_test = b.shape[0]
num_train = a.shape[0]
dists = np.zeros((num_test, num_train))
for i in range(num_test):
    dists[i,:] = np.sqrt(np.sum(np.square(a - b[i,:]), axis = 1))
print(dists)
[[24.59674775 13.41640786  2.23606798]
 [46.95742753 35.77708764 24.59674775]]
- Without any loops:
# b is the test set, so test_sum is computed from b; a is the training set
>> test_sum = np.sum(np.square(b),axis=1)
>> print(test_sum)
[ 855 2655]
>> train_sum = np.sum(np.square(a),axis=1)
>> print(train_sum)
[ 30 255 730]
>> test_sum.reshape(-1,1)
array([[ 855],
       [2655]])
>> test_sum.reshape(-1,1)+train_sum
array([[ 885, 1110, 1585],
       [2685, 2910, 3385]])
>> inner_product=np.dot(b,a.T)
>> print(inner_product)
[[ 140  465  790]
 [ 240  815 1390]]
>> dists = np.sqrt(-2*inner_product+test_sum.reshape(-1,1)+train_sum)
>> print(dists)
[[24.59674775 13.41640786  2.23606798]
 [46.95742753 35.77708764 24.59674775]]
The result is the same.
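The same comparison can be automated on larger random inputs. The sketch below is one possible check, not the assignment's own test code: the function names dists_one_loop and dists_no_loops and the random data are made up for illustration. The np.maximum clamp is an added assumption that is not in the original function; it guards against tiny negative values that rounding can produce before the square root.

```python
import numpy as np

def dists_one_loop(X_train, X):
    """Reference version: one loop over the test rows."""
    dists = np.zeros((X.shape[0], X_train.shape[0]))
    for i in range(X.shape[0]):
        dists[i, :] = np.sqrt(np.sum(np.square(X_train - X[i, :]), axis=1))
    return dists

def dists_no_loops(X_train, X):
    """Vectorized version built from the expanded squared distance."""
    test_sum = np.sum(np.square(X), axis=1).reshape(-1, 1)  # shape (num_test, 1)
    train_sum = np.sum(np.square(X_train), axis=1)          # shape (num_train,)
    sq = test_sum - 2 * np.dot(X, X_train.T) + train_sum
    # Assumption (not in the original code): clamp tiny negative values
    # caused by rounding so that np.sqrt never returns nan.
    return np.sqrt(np.maximum(sq, 0))

rng = np.random.default_rng(0)          # made-up data, small on purpose
X_train = rng.random((50, 12))
X = rng.random((7, 12))

same = np.allclose(dists_one_loop(X_train, X), dists_no_loops(X_train, X))
print(same)  # True
```

np.allclose is used instead of == because the two versions add the same terms in a different order, so their floating-point results can differ in the last few digits.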