**
**
tf.where(tensor)
当只有一个输入时,输入为布尔型,其返回的是值为True的位置.
In [3]: a = tf.random.normal([3,3])#生成3行3列的随机数组
In [4]: a
Out[4]:
In [6]: mask = a>0 #返回>0的布尔值
In [7]: mask
Out[7]:
In [8]: idx = tf.where(mask)#返回值为True的位置索引
In [9]: idx
Out[9]:
In [10]: tf.gather_nd(a,idx)#根据索引从数组a中提取出数据
Out[10]:
In [11]: tf.boolean_mask(a,mask)
Out[11]:
tf.where(cond,A,B)输入为三个参数时,当cond中值为True时,取a对应值,否则取b中对应值
In [12]: mask
Out[12]:
In [13]: a = tf.ones([3,3])
In [14]: a
Out[14]:
In [15]: b = tf.zeros([3,3])
In [16]: b
Out[16]:
In [17]: c =tf.where(mask,a,b)
In [18]: c
Out[18]:
**
**
tf.scatter_nd(indices,updates,shape,name=None)
只能在数值全部为0的底板上面指定位置更新数据.
shape为底板的shape,类型为tensor
indices为更新位置的索引
updates为更新的新值
In [21]: shape = tf.constant([8])#指定底板的shape为(8,)
In [22]: shape
Out[22]:
In [23]: indices = tf.constant([[4],[3],[1],[7]])#指定更新位置的索引为4,3,1,7
In [24]: indices
Out[24]:
In [26]: updates = tf.constant([9,10,11,12])#指定更新的新值为[9,10,11,12]
In [27]: updates
Out[27]:
In [28]: tf.scatter_nd(indices,updates,shape)
Out[28]:
利用sactter_nd更新现有的tensor
将tensor A中需要更新位置的数据取出并更新到底板->得到A’
A = A-A’清零需要更新位置的数据
将需要更新的新数据更新到底板->得到A’’
A = A+A’'将数据更新到原tensor中.
二 meshgrid
In [39]: x = tf.linspace(-2.,2.,5)
In [40]: y = tf.linspace(-2.,2.,5)
In [41]: point_x,point_y = tf.meshgrid(x,y)
In [42]: point_x.shape
Out[42]: TensorShape([5, 5])
In [43]: point_y.shape
Out[43]: TensorShape([5, 5])
In [45]: points = tf.stack([point_x,point_y],axis=2)
In [46]: points
Out[46]:
In [47]: points.shape
Out[47]: TensorShape([5, 5, 2])