get_shape_list

def get_shape_list(tensor, expected_rank=None, name=None)
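参数:
tensor:需要获取shape的tf.Tensor。
expected_rank:可选,int或者int的list。输入tensor期望的rank(也就是张量的维数),如果输入tensor的rank不等于这个数,或者不在这个list中,会抛出ValueError异常。
name:可选,报错信息中使用的tensor名称,默认为tensor.name。
返回:tensor的shape列表,其中静态维度为python int,动态维度为tf.Tensor标量。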

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import numbers

import six
import tensorflow as tf

def get_shape_list(tensor, expected_rank=None, name=None):
    """Returns a list of the shape of tensor, preferring static dimensions.

    Args:
      tensor: A tf.Tensor object to find the shape of.
      expected_rank: (optional) int or list of ints. The accepted rank(s) of
        `tensor`. If this is specified and the rank of `tensor` is not among
        them, an exception will be thrown by `assert_rank`.
      name: Optional name of the tensor for the error message. Only used when
        `expected_rank` is given; `assert_rank` falls back to `tensor.name`
        when this is None.

    Returns:
      A list of dimensions of the shape of tensor. All static dimensions will
      be returned as python integers, and dynamic dimensions will be returned
      as tf.Tensor scalars.

    Raises:
      ValueError: From `assert_rank` if the rank does not match
        `expected_rank`.
    """
    # `name` only feeds assert_rank's error message, and assert_rank resolves
    # a None name itself, so `tensor.name` is not read here. (Reading it
    # eagerly can fail for some tensor types, e.g. eager tensors in TF 2.x.)
    if expected_rank is not None:
        assert_rank(tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    # Positions whose size is unknown at graph-construction time.
    non_static_indexes = [index for index, dim in enumerate(shape) if dim is None]
    if not non_static_indexes:
        return shape

    # Only build the dynamic-shape op when at least one dimension is unknown;
    # static dimensions stay plain python ints.
    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape

def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
      tensor: A tf.Tensor to check the rank of.
      expected_rank: Python integer or iterable of integers, the accepted
        rank(s). Any `numbers.Integral` value (including numpy integer
        scalars) is treated as a single rank.
      name: Optional name of the tensor for the error message. Defaults to
        `tensor.name`.

    Raises:
      ValueError: If the actual rank of `tensor` is not one of the expected
        ranks, including when the rank is not statically known.
    """
    if name is None:
        name = tensor.name

    # numbers.Integral covers int/long and numpy integer scalars, so a numpy
    # int is not mistakenly iterated as if it were a list of ranks.
    if isinstance(expected_rank, numbers.Integral):
        expected_ranks = {expected_rank}
    else:
        expected_ranks = set(expected_rank)

    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_ranks:
        # tf.compat.v1 matches the rest of this file and keeps working on
        # TF 2.x, where the top-level tf.get_variable_scope was removed.
        scope_name = tf.compat.v1.get_variable_scope().name
        # `%s` (not `%d`) for actual_rank: ndims is None for an unknown rank,
        # and `%d` would turn the intended ValueError into a TypeError.
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%s` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))

if __name__ == '__main__':
    # Demo: print the shape list of two constants and one placeholder. Each
    # tensor is created lazily, right before its shape is queried, so graph
    # ops are built in the same order as the per-case output.
    demos = (
        ('zero_shape:', lambda: tf.constant([0.0, 0.1, 0.2]), [1, 3]),
        ('one_shape:',
         lambda: tf.constant([[0.0, 0.1, 0.2], [0.0, 0.1, 0.2]]),
         [2, 3]),
        ('two_shape:',
         lambda: tf.compat.v1.placeholder(tf.int32, [32, 512]),
         2),
    )
    for label, make_tensor, rank in demos:
        shape_list = get_shape_list(make_tensor(), expected_rank=rank)
        print(label)
        print(shape_list)

输出结果:

zero_shape:
[3]
one_shape:
[2, 3]
two_shape:
[32, 512]

你可能感兴趣的:(get_shape_list)