TensorFlow2.0入门到进阶3.3 —— 基础数据类型API之strings与ragged_tensor

文章目录

  • 1、strings
    • 1.1 strings是什么?
    • 1.2 tf.strings 的优点是什么?
    • 1.3 程序实例
  • 2.ragged_tensor

1、strings

1.1 strings是什么?

看到strings,这个不是字符串吗?没错,它在这里的作用就是字符串,那tensorflow为什么还要单独拿出来呢,Python已经有字符串了呀!
其实,加入tf.strings的其中一个重要的作用是可以使字符串成为TensorFlow的第一公民,可以直接加入到模型的输入中,这对NLP领域是很有用的

1.2 tf.strings 的优点是什么?

之前在NLP中如果要将字符串进行计算,需要进行下面几步:

  • 1、首先需要将字符串分词,例如英文常见用空格、标点分词,中文使用分词器或者干脆按字分词
  • 2、其次需要计算一个词表,能将每个词对应到一个对应的数字上,一般还要加入一些例如[pad],[unk]等特殊符号
  • 3、在训练前将训练集的所有字符串经过上面的结果,都转换为数字符号。或者使用generator等技术在训练中流式转换

tf.strings的目的,就是我们为什么不能直接将字符串输入,避免上面的几步?这样做有几个好处

  1. 避免了很多多余的代码,比如额外的分词、计算词表等
  2. 保证模型的统一性,例如模型本身就包含了分词和符号转换,就可以直接把模型打包、发布
  3. 模型发布也可以直接用tensorflow serve等完成,避免第三方介入

1.3 程序实例

  1. 首先,创建一个字符串常量,之后可以用tf.strings.length查看其长度
    对于编码方式的转换,这里可以回归一下之前的ASCII码,a对应多少呀?没错,0x61,刚好是97,其实UTF-8就是ASCII的一种扩展
#strings
t = tf.constant("abcd")
print(t)
print(tf.strings.length(t))
print(tf.strings.length(t,unit="UTF8_CHAR"))
#将字符串t的编码方式从unicode转换为"UTF8
print(tf.strings.unicode_decode(t,"UTF8"))

结果:

tf.Tensor(b'abcd', shape=(), dtype=string)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor([ 97  98  99 100], shape=(4,), dtype=int32)
  1. 字符串数组的编码转换,对于集中编码方式,如果想详细了解,可以看:字符编码ANSI和ASCII区别、Unicode和UTF-8区别
# string array
t = tf.constant(["cafe","coffee","咖啡"])
print(tf.strings.length(t,"UTF8_CHAR"))
r = tf.strings.unicode_decode(t,"UTF8")
print(t)
print(r)

结果:

tf.Tensor([4 6 2], shape=(3,), dtype=int32)
tf.Tensor([b'cafe' b'coffee' b'\xe5\x92\x96\xe5\x95\xa1'], shape=(3,), dtype=string)
<tf.RaggedTensor [[99, 97, 102, 101], [99, 111, 102, 102, 101, 101], [21654, 21857]]>

2.ragged_tensor

ragged_tensor的意思就是不完整的张量,中间允许有元素空缺,对于张量的概念可以看上一节内容:TensorFlow2.0入门到进阶3.2 —— 基础数据类型API之常量(constant)

  1. 创建及显示
# ragged_tensor  不完整的矩阵,如每行中元素多少不一这种不规则情况
r = tf.ragged.constant([[11,12],[21,22,23],[],[41]])
#index op
print(r)
print(r[1])
print(r[:1])
print(r[0,0])

结果:

<tf.RaggedTensor [[11, 12], [21, 22, 23], [], [41]]>
tf.Tensor([21 22 23], shape=(3,), dtype=int32)
<tf.RaggedTensor [[11, 12]]>
tf.Tensor(11, shape=(), dtype=int32)

但是,这里要注意,如果你选中的部分数据中有空缺元素,将被报错,如下所示:

print(r[2,0])
# InvalidArgumentError: slice index 0 of dimension 0 out of bounds. [Op:StridedSlice]
  1. 拼接
# ops on ragged tensor
r2 = tf.ragged.constant([[51,52],[],[71]])
print(tf.concat([r,r2],axis=0))
# 横向拼接
r3 = tf.ragged.constant([[51,52],[53],[54],[55,56]])
print(tf.concat([r,r3],axis=1))
  1. 通过补0将不完整的张量变为普通张量:
#将ragged_tensor 变为普通的tensor,空余位将补0
r4 = r.to_tensor()
print(r4)

结果:

tf.Tensor(
[[11 12  0]
 [21 22 23]
 [ 0  0  0]
 [41  0  0]], shape=(4, 3), dtype=int32)

你可能感兴趣的:(TensorFlow2.0入门到进阶3.3 —— 基础数据类型API之strings与ragged_tensor)