一个运行TensorFlow operation的类。会话包含以下两种开启方式
1 TensorFlow 使用 tf.Session 类来表示客户端程序(通常为 Python 程序,但也提供了使用其他语言的类似接口)与 C++ 运行时之间的连接
2 tf.Session 对象使用分布式 TensorFlow 运行时提供对本地计算机中的设备和远程设备的访问权限。
会话可能拥有的资源,如 tf.Variable,tf.QueueBase和tf.ReaderBase。当这些资源不再需要时,释放这些资源非常重要。因此,需要调用tf.Session.close会话中的方法,或将会话用作上下文管理器。以下两个例子作用是一样的:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def session_demo():
"""
会话演示
:return:
"""
a_t = tf.constant(10)
b_t = tf.constant(20)
# 不提倡直接运用这种符号运算符进行计算
# 更常用tensorflow提供的函数进行计算
# c_t = a_t + b_t
c_t = tf.add(a_t, b_t)
print("tensorflow实现加法运算:\n", c_t)
# 开启会话
# 传统的会话定义
# sess = tf.Session()
# sum_t = sess.run(c_t)
# print("sum_t:\n", sum_t)
# sess.close()
# 开启会话
with tf.Session() as sess:
# sum_t = sess.run(c_t)
# 想同时执行多个tensor
print(sess.run([a_t, b_t, c_t])) # 同时运行三个,同时也需要3个变量来接收
# 方便获取张量值的方法
# print("在sess当中的sum_t:\n", c_t.eval())
# 会话的图属性
print("会话的图属性:\n", sess.graph)
return None
if __name__ == '__main__':
session_demo()
运行结果:
# 运行会话并打印设备信息
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
log_device_placement=True))
会话可以分配不同的资源在不同的设备上运行。
/job:worker/replica:0/task:0/device:CPU:0
device_type:类型设备(例如CPU,GPU,TPU)
示例代码:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def session_demo():
"""
会话演示
:return:
"""
a_t = tf.constant(10)
b_t = tf.constant(20)
# 不提倡直接运用这种符号运算符进行计算
# 更常用tensorflow提供的函数进行计算
# c_t = a_t + b_t
c_t = tf.add(a_t, b_t)
print("tensorflow实现加法运算:\n", c_t)
# 开启会话
# 传统的会话定义
# sess = tf.Session()
# sum_t = sess.run(c_t)
# print("sum_t:\n", sum_t)
# sess.close()
# 开启会话
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
# sum_t = sess.run(c_t)
# 想同时执行多个tensor
print(sess.run([a_t, b_t, c_t])) # 同时运行三个,同时也需要3个变量来接收
# 方便获取张量值的方法
# print("在sess当中的sum_t:\n", c_t.eval())
# 会话的图属性
print("会话的图属性:\n", sess.graph)
return None
if __name__ == '__main__':
session_demo()
使用tf.operation.eval()也可运行operation,但需要在会话中运行
# 创建图
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# 创建会话
sess = tf.Session()
# 计算C的值
print(sess.run(c))
print(c.eval(session=sess))
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
def session_run_demo():
"""
会话的run方法
:return:
"""
# 定义占位符
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
sum_ab = tf.add(a, b)
print("sum_ab:\n", sum_ab)
# 开启会话
with tf.Session() as sess:
print("占位符的结果:\n", sess.run(sum_ab, feed_dict={a: 3.0, b: 4.0}))
return None
if __name__ == '__main__':
session_run_demo()
请注意运行时候报的错误error:
RuntimeError:如果这Session是无效状态(例如已关闭)。 TypeError:如果fetches或者feed_dict键的类型不合适。 ValueError:如果fetches或feed_dict键无效或引用 Tensor不存在的键。
补充:tf.InteractiveSession:用于交互式上下文中的TensorFlow