TensorFlow2.0:高阶操作

**

一 tf.where( )函数

**

tf.where(tensor)
当只有一个输入时,输入为布尔型,其返回的是值为True的位置.

In [3]: a = tf.random.normal([3,3])#生成3行3列的随机数组                                                                                                                                                    

In [4]: a                                                                                                                                                                                                   
Out[4]: 


In [6]: mask = a>0 #返回>0的布尔值                                                                                                                                                                          

In [7]: mask                                                                                                                                                                                                
Out[7]: 


In [8]: idx = tf.where(mask)#返回值为True的位置索引                                                                                                                                                         

In [9]: idx                                                                                                                                                                                                 
Out[9]: 


In [10]: tf.gather_nd(a,idx)#根据索引从数组a中提取出数据                                                                                                                                                    
Out[10]: 

In [11]: tf.boolean_mask(a,mask)                                                                                                                                                                            
Out[11]: 

tf.where(cond,A,B)输入为三个参数时,当cond中值为True时,取a对应值,否则取b中对应值

In [12]: mask                                                                                                                                                                                               
Out[12]: 


In [13]: a = tf.ones([3,3])                                                                                                                                                                                 

In [14]: a                                                                                                                                                                                                  
Out[14]: 


In [15]: b = tf.zeros([3,3])                                                                                                                                                                                

In [16]: b                                                                                                                                                                                                  
Out[16]: 


In [17]: c =tf.where(mask,a,b)                                                                                                                                                                              

In [18]: c                                                                                                                                                                                                  
Out[18]: 


**

二 tf.scatter_nd( )函数

**

tf.scatter_nd(indices,updates,shape,name=None)

只能在数值全部为0的底板上面指定位置更新数据.
shape为底板的shape,类型为tensor
indices为更新位置的索引
updates为更新的新值

In [21]: shape = tf.constant([8])#指定底板的shape为(8,)                                                                                                                                                     

In [22]: shape                                                                                                                                                                                              
Out[22]: 

In [23]: indices = tf.constant([[4],[3],[1],[7]])#指定更新位置的索引为4,3,1,7                                                                                                                               

In [24]: indices                                                                                                                                                                                            
Out[24]: 


In [26]: updates = tf.constant([9,10,11,12])#指定更新的新值为[9,10,11,12]                                                                                                                                   

In [27]: updates                                                                                                                                                                                            
Out[27]: 

In [28]: tf.scatter_nd(indices,updates,shape)                                                                                                                                                               
Out[28]: 

利用sactter_nd更新现有的tensor
将tensor A中需要更新位置的数据取出并更新到底板->得到A’
A = A-A’清零需要更新位置的数据
将需要更新的新数据更新到底板->得到A’’
A = A+A’'将数据更新到原tensor中.

二 meshgrid

In [39]: x = tf.linspace(-2.,2.,5)                                                                                                                                                                          

In [40]: y = tf.linspace(-2.,2.,5)                                                                                                                                                                          

In [41]: point_x,point_y = tf.meshgrid(x,y)                                                                                                                                                                 

In [42]: point_x.shape                                                                                                                                                                                      
Out[42]: TensorShape([5, 5])

In [43]: point_y.shape                                                                                                                                                                                      
Out[43]: TensorShape([5, 5])

In [45]: points = tf.stack([point_x,point_y],axis=2)                                                                                                                                                        

In [46]: points                                                                                                                                                                                             
Out[46]: 


In [47]: points.shape                                                                                                                                                                                       
Out[47]: TensorShape([5, 5, 2])

你可能感兴趣的:(tensorflow2.0)