【Python】numpy广播broadcast与np.newaxis()函数详解

【Python】numpy广播broadcast与np.newaxis()函数详解

文章目录

  • 【Python】numpy广播broadcast与np.newaxis()函数详解
    • 1. 广播 broadcast
      • 计算对象数组形状相同
      • 计算对象数组形状不同
    • 2. numpy.newaxis 函数
      • 在行方向增加维度
      • 在列方向增加维度
      • 多维数组增加维度
    • 3. 广播broadcast与newaxis结合使用

1. 广播 broadcast

广播(broadcast)是Python numpy库中十分常用的功能之一。针对不同形状的数组,依据一定规则进行元素的数学运算,一般是“较小的数据在较大的数组上广播”,以便兼容不同的形状。

计算对象数组形状相同

当计算对象数数组形状相同,不会触发广播功能。
比如下面两个数组进行加法运算:

import numpy as np

a = np.array([[ 4,  7,  9], 
              [10,  2, 10]])
b = np.array([[ 3,  2,  6], 
              [ 8,  1,  5]])
a + b

输出结果如下,每个位置对应元素相加:

array([[ 7,  9, 15],
       [18,  3, 15]])

计算对象数组形状不同

当计算对象数组形状不同时,则会触发广播功能。

广播应用规则: 输入数组具有相同形状,或者其中一个数组某维度的尺寸为1。

如果不满足该规则,那么会抛出异常:ValueError: operands could not be broadcast together。

相同形状的数组运算上面已经说明,那么什么是“其中一个数组某维度的尺寸为1”?

比如下面两个数组:

>>> a = np.array([[ 4,  7,  9], 
...               [10,  2, 10],
...               [ 1,  4,  8]])
>>> b = np.array([[ 3,  2,  6]])
>>> print(a.shape, b.shape)
(3, 3) (1, 3)
>>> a + b
array([[ 7,  9, 15],
       [13,  4, 16],
       [ 4,  6, 14]])

可以看到,输出结果等于下面两个数组a和c相加:

>>> c = np.array([[ 3,  2,  6], 
...               [ 3,  2,  6], 
...               [ 3,  2,  6]])
>>> a + c
array([[ 7,  9, 15],
       [13,  4, 16],
       [ 4,  6, 14]])

也就是说,在计算时,将b在第一个维度上(原维度为1)进行了“拉伸”或者说“复制”,然后与a在元素级别相加。

如果b的第一个维度不为1,就会抛出异常:

>>> d = np.array([[ 3,  2,  6], 
...               [ 1,  4,  7]])
>>> d.shape
(2, 3)
>>> a + d
Traceback (most recent call last):
  File "", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,3) (2,3) 

矩阵的元素进行乘法运算也类似:

>>> a = np.array([[ 4,  7,  9], 
...               [10,  2, 10],
...               [ 1,  4,  8]])
>>> b = np.array([[ 3,  2,  6]])
>>> print(a.shape, b.shape)
(3, 3) (1, 3)
>>> a * b
array([[12, 14, 54],
       [30,  4, 60],
       [ 3,  8, 48]])

2. numpy.newaxis 函数

顾名思义,newaxis函数的功能是增加新的维度,将一维数组变为二维,二维数组变为三维等。
对于下面一维数组x0,可以使用newaxis在不同方向增加维度,使其成为形状不同的二维数组,进而继续增加维度。

>>> x0 = np.arange(4)
>>> x0
array([0, 1, 2, 3])
>>> x0.shape
(4,)

在行方向增加维度

在行方向增加维度,从一维数组变为二维数组。

>>> x1 = x[np.newaxis, :]
>>> x1
array([[0, 1, 2, 3]])
>>> x1.shape
(1, 4)

在列方向增加维度

在列方向增加维度,从一维数组变为另一种形状的二维数组。

>>> x2 = x[:, np.newaxis]
>>> x2
array([[0],
       [1],
       [2],
       [3]])
>>> x2.shape
(4, 1)

多维数组增加维度

对于多维数组,同样可以使用newaxis增加一个维度。

>>> y = np.arange(24)
>>> y0 = y.reshape(2,3,4) # 创建形状为(2,3,4)的三维数组
>>> y0
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> y1 = y0[:,np.newaxis, :, :] # 在第一维和第二维之间增加一个新维度,形成形状为(2,1,3,4)的数组
>>> y1
array([[[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]],


       [[[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]]])
>>> y1.shape
(2, 1, 3, 4)

3. 广播broadcast与newaxis结合使用

广播与newaxis结合使用,可以方便地操作两个数组进行各种操作,比如可以获取两个数组的外积:

>>> a = np.array([0, 10, 20, 30])
>>> b = np.array([ 1,  2,  3])
>>> print(a.shape, b.shape)
(4,) (3,)
>>> a[:, np.newaxis] * b # 利用newaxis得到外积
array([[ 0,  0,  0],
       [10, 20, 30],
       [20, 40, 60],
       [30, 60, 90]])
>>> np.outer(a,b) # 调用outer函数计算外积,结果与上面相同
array([[ 0,  0,  0],
       [10, 20, 30],
       [20, 40, 60],
       [30, 60, 90]])  

其中,a[:, np.newaxis]操作之后,创造了形状为(4,1)的数组。在与b进行乘法运算时,触发了广播机制。

a维度增加:

>>> a[:, np.newaxis]
array([[ 0],
       [10],
       [20],
       [30]])
>>> a[:, np.newaxis].shape
(4, 1)

在与b相乘时,a[:, np.newaxis]和b都根据对方维度进行了广播,相当于将形状为(4,1)的a[:, np.newaxis],变成了形状为(4,3)的a1,将形状为(3,)的一维数组b,变成了形状为(4,3)的二维数组b1,然后进行元素的乘法运算。
也就是说,将a矩阵维度为1的维度根据b的相应维度进行了“拉伸”,同样,将b矩阵维度为1的维度根据a的相应维度进行了“拉伸”,分别对齐到了各方向上最大的维度。

>>> a1 = np.array([[ 0,  0,  0],
...                [10, 10, 10],
...                [20, 20, 20],
...                [30, 30, 30]])
>>> b1 = np.array([[ 1,  2,  3],
...                [ 1,  2,  3],
...                [ 1,  2,  3],
...                [ 1,  2,  3]])
>>> a1 * b1
array([[ 0,  0,  0],
       [10, 20, 30],
       [20, 40, 60],
       [30, 60, 90]])

你可能感兴趣的:(python,python,numpy,开发语言)