Pytorch张量基本、切片操作总结

最近在学习深度学习的过程中发现pytorch中的切片操作又忘记了很多,于是专门重新进行梳理。以下演示内容为jupyter notebook.

(一)一维向量的操作

先从最基本的一维向量看起,一维向量的操作其实很像numpy一维数组,基本定义如下:

[起始索引:结束索引:步长]

这里要注意几点:

1.默认步长为1

2.起始索引:结束索引 是一个左闭右开区间,即结束索引的值不取

3.有反向索引,具体如下:

Pytorch张量基本、切片操作总结_第1张图片

1.导入包,创建行向量

import torch
# 创建一个行向量
x=torch.arange(12)
x
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

2.取某一个索引元素

# 取倒数第二个元素
x[-2]
tensor(10)

3.正常切片操作

# [起始2:起始8:步长2](左闭右开)
x[2:8:2]
tensor([2, 4, 6])

4.反向读取

# 取倒数第4-倒数第1(左闭右开)
x[-4:-1]
tensor([ 8,  9, 10])

 5.(注意)pytorch不支持反向获取序列

注意理解前面的反向索引,反向索引也要正向读取,比如x[-1:-3]也是不行的

# 不支持逆向读取
x[-4:-1:-1]

6.默认区间

# 默认区间
x[:]
x[::]
x[::1]

以上三种输出都一样:

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

(二)多维张量的基本操作

对于一维向量而言,操作使用冒号:,而对于多维张量,维度与维度之间区分使用逗号,比如对二维张量,x[ , ]。为避免混乱,以下用二维张量演示,可以理解为矩阵。

1.改变张量的形状-reshape

# 改变张量的形状
x=x.reshape(3,4)
x
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

2.自适应维度变化-在相应维度设-1 

# 自适应维度变化
x=x.reshape(3,-1)
x
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

这里就是3×-1=12,-1算出来为4 

3.获得张量中元素的总数-numel

# 获得张量中元素的总数
x.numel()

 结果为3×4=12

4.没有逗号,在第一维操作(dim=0,维度的下标也从0开始)

取出x的dim=0的第一个:

# 没有逗号,在第一维度dim=0操作
# 取出x的dim=0的第一个(这里就是第一行)
x[0]
tensor([0, 1, 2, 3])

取【1:2)行: 

# 没有逗号,在第一维度dim=0操作
# 取【1:2),即1行到2行(不包括2行,注意有0行)
x[1:2]
tensor([[4, 5, 6, 7]])

 将【1:3)行的元素全部换成5:

# 没有逗号,在第一维度dim=0操作
# 将【1:3)行的元素全部换成5
x[1:3]=5
x
tensor([[0, 1, 2, 3],
        [5, 5, 5, 5],
        [5, 5, 5, 5]])

4.有逗号,按照逗号区分维度间的操作

# 有逗号,在dim=0和dim=1上分别操作
# 先取【1:3)行,在此基础上取【2:4)列
x[1:3,2:4]
tensor([[ 6,  7],
        [10, 11]])
# 有逗号,在dim=0和dim=1上分别操作
# 先取【0:2)行,列上默认(即全选)
x[0:2,:]
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

 反向顺序:

# 有逗号,在dim=0和dim=1上分别操作
# 行上默认,列取【-3,-1)列
x[:,-3:-1]
tensor([[ 1,  2],
        [ 5,  6],
        [ 9, 10]])

...的用法:相当于这一维度默认

# 有逗号,在dim=0上取【1:2)行
# dim=1上默认(全取)
x[1:2,...]
tensor([[4, 5, 6, 7]])

 5.张量连结-注意需要拼接的维度一定要相同

# 张量连结
X=torch.arange(12,dtype=torch.float32).reshape((3,4))
Y=torch.tensor([[2.0,1,4,3],[1,2,3,4],[4,3,2,1]])
torch.cat((X,Y),dim=0),torch.cat((X,Y),dim=1)
(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [ 2.,  1.,  4.,  3.],
         [ 1.,  2.,  3.,  4.],
         [ 4.,  3.,  2.,  1.]]),
 tensor([[ 0.,  1.,  2.,  3.,  2.,  1.,  4.,  3.],
         [ 4.,  5.,  6.,  7.,  1.,  2.,  3.,  4.],
         [ 8.,  9., 10., 11.,  4.,  3.,  2.,  1.]]))

  6.广播机制

# 广播机制
a=torch.arange(3).reshape((3,1))
b=torch.arange(2).reshape((1,2))
a,b
(tensor([[0],
         [1],
         [2]]),
 tensor([[0, 1]]))
# a,b都会先变成3×2矩阵再相加
a+b
tensor([[0, 1],
        [1, 2],
        [2, 3]])

(三)更高维的演示

# 先随机生成一个三维张量
x=torch.rand(4,3,5)
x
tensor([[[0.3699, 0.1947, 0.6766, 0.5857, 0.2937],
         [0.2248, 0.6221, 0.1842, 0.2236, 0.4396],
         [0.2535, 0.6115, 0.8398, 0.5736, 0.2702]],

        [[0.6308, 0.1010, 0.0042, 0.5904, 0.2101],
         [0.7639, 0.8016, 0.8733, 0.4037, 0.7373],
         [0.1602, 0.9687, 0.0013, 0.2576, 0.6159]],

        [[0.2155, 0.2485, 0.9657, 0.4890, 0.0394],
         [0.8007, 0.6122, 0.2834, 0.7095, 0.9711],
         [0.5180, 0.8917, 0.0647, 0.6129, 0.7661]],

        [[0.9430, 0.1931, 0.3349, 0.0188, 0.1079],
         [0.1710, 0.2816, 0.2003, 0.3089, 0.1025],
         [0.1610, 0.9168, 0.7699, 0.4525, 0.4716]]])

x[:2,:,2:]

# 取第一维度【0:2),第二维度默认(全取),第三维度【2:】
x[:2,:,2:]
tensor([[[0.6766, 0.5857, 0.2937],
         [0.1842, 0.2236, 0.4396],
         [0.8398, 0.5736, 0.2702]],

        [[0.0042, 0.5904, 0.2101],
         [0.8733, 0.4037, 0.7373],
         [0.0013, 0.2576, 0.6159]]])

x[:-2,:,:2]

# 取第一维度【:-2),第二维度默认(全取),第三维度【:2)
x[:-2,:,:2]
tensor([[[0.3699, 0.1947],
         [0.2248, 0.6221],
         [0.2535, 0.6115]],

        [[0.6308, 0.1010],
         [0.7639, 0.8016],
         [0.1602, 0.9687]]])

 x[:2,:,0:5:2]

# 取第一维度【:2),第二维度默认(全取),第三维度【0:5:2】
x[:2,:,0:5:2]
tensor([[[0.3699, 0.6766, 0.2937],
         [0.2248, 0.1842, 0.4396],
         [0.2535, 0.8398, 0.2702]],

        [[0.6308, 0.0042, 0.2101],
         [0.7639, 0.8733, 0.7373],
         [0.1602, 0.0013, 0.6159]]])

(四)总结

根据逗号,来区分在哪一维度操作,根据冒号:来看这一维度的切片操作

把握这一句话,然后记住如果是默认那这一维全选就可以了。一般深度学习不会用到五维及以上,所以重点把握三维的各类变换就可以了。

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能,python)