深度学习pytorch——索引与切片

indexing

import torch 
a = torch.rand(4,3,28,28)    # 表示4张28*28的rgb图
print(a[0].shape)            # a[0]获得第一张图片
print(a[0,0].shape)          # a[0,0]获得第一张图片的r图
print(a[0,0,2,4])            # 获得第一张图片第一个通道的一个像素点,因此得到的是一个标量

select first/last N 

# select first/last N
print(a[:2].shape)          # :2 => 0,1
print(a[:2,:1,:,:].shape)   # :1 => 0
print(a[:2,1:,:,:].shape)   # 1: => 1,2
print(a[:2,-1:,:,:].shape)  # -1: => 2

select by steps 

# select by steps
print(a[:,:,0:28:2,0:28:2].shape)   # 0:28:2 => 从0-28,步长为2
print(a[:,:,::2,::2].shape)         # ::2 => 从0-28,步长为2
# 总结
# 1. : => all
# 2. :n => 从最开始到n,不包括n
# 3. n: => 从n到最后
# 4. start:end => 从start到end,不包含end
# 5. start:end:steps => 从start到end,不包含end,步长为2

select by specific index 

# select by specific index
print(a.index_select(0,torch.tensor([0,2])).shape)  # index_select() 第一个参数是维度,第二个参数是具体的索引号,但是索引号必须是tensor,所以要使用torch.tensor()
print(a.index_select(2,torch.arange(28)).shape)     # torch.arange(28) => tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        # 18, 19, 20, 21, 22, 23, 24, 25, 26, 27])

 ...

# ...
print(a[...].shape)         # a[...] => a[:,:,:,:]
print(a[0,...].shape)       # a[0,...] => a[0,:,:,:]
print(a[:,1,...].shape)     # a[:,1,...] => a[:,1,:,:]
print(a[...,:2].shape)      # a[...,:2] => a[:,:,:,:2]

select by mask 

# select by mask
# .masked_select() 会将数据默认打平->之所以打平是因为当满足某条件的位数是根据内容才能确定的
x = torch.randn(3,4)
mask = x.ge(0.5)                            # 将大于等于0.5的数取为ture
print(mask)                                 # 掩码
print(torch.masked_select(x,mask))          # 根据掩码取数据和原shape无关
print(torch.masked_select(x,mask).shape)

select by flatten index 

# select by flatten index   打平
src = torch.tensor([[4,3,5],[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5])))        # 打平,取索引为0,2,5的数

你可能感兴趣的:(AI,深度学习,pytorch,人工智能)