5/28/2021

In TensorFlow, get the names of all the tensors in a graph


To get all nodes in the graph (type tensorflow.core.framework.node_def_pb2.NodeDef):

all_nodes = list(tf.get_default_graph().as_graph_def().node)

To get all ops in the graph (type tensorflow.python.framework.ops.Operation):

all_ops = tf.get_default_graph().get_operations()

To get all variables in the graph (type tensorflow.python.ops.resource_variable_ops.ResourceVariable):

all_vars = tf.global_variables()

To get all tensors in the graph (type tensorflow.python.framework.ops.Tensor):

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

To get all placeholders in the graph (type tensorflow.python.framework.ops.Tensor):

all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]
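
Each of the lists above holds objects rather than strings, so the names still have to be read from the .name attribute. Here is a minimal sketch of that step. It assumes the tf.compat.v1 aliases so that it runs under TensorFlow 2 as well, and it builds a tiny illustrative graph (the names x, w and y are made up for this example) instead of relying on the default graph:

```python
import tensorflow as tf

# Build a small graph explicitly; x, w and y are hypothetical names for illustration.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf.compat.v1.get_variable("w", shape=[3, 2])
    y = tf.matmul(x, w, name="y")

    # Same recipes as above, applied to this graph.
    all_nodes = list(graph.as_graph_def().node)
    all_ops = graph.get_operations()
    all_vars = tf.compat.v1.global_variables()
    all_tensors = [t for op in all_ops for t in op.values()]
    all_placeholders = [t for op in all_ops if op.type == "Placeholder" for t in op.values()]

# Every one of these objects exposes a .name attribute.
node_names = [n.name for n in all_nodes]
op_names = [op.name for op in all_ops]
var_names = [v.name for v in all_vars]
tensor_names = [t.name for t in all_tensors]
placeholder_names = [t.name for t in all_placeholders]

print(tensor_names)
```

Tensor names carry an output index suffix such as ":0", while node and op names do not, so the two lists are not interchangeable when looking something up with graph.get_tensor_by_name().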

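TensorFlow 2

TensorFlow 2 executes eagerly by default and has no top-level tf.get_default_graph() (it survives only as tf.compat.v1.get_default_graph()). To get a graph, create a tf.function first and access the graph attribute of one of its concrete functions, for example:

graph = func.get_concrete_function().graph

where func is a tf.function. If func takes arguments and was declared without an input_signature, pass example inputs or tf.TensorSpec objects to get_concrete_function().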
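
As a rough end-to-end sketch of the TensorFlow 2 route: the function body, the op name "out" and the input shape below are assumptions chosen only for illustration. Once the graph is in hand, the same list comprehensions as before apply:

```python
import tensorflow as tf

# Hypothetical tf.function used only to have something to trace.
@tf.function
def func(x):
    return tf.nn.relu(x * 2.0, name="out")

# func takes an argument, so describe it with a TensorSpec when tracing.
concrete = func.get_concrete_function(tf.TensorSpec(shape=[None, 3], dtype=tf.float32))
graph = concrete.graph

all_ops = graph.get_operations()
tensor_names = [t.name for op in all_ops for t in op.values()]
placeholder_names = [t.name for op in all_ops if op.type == "Placeholder" for t in op.values()]

print(tensor_names)
print(placeholder_names)
```

In the traced graph the function's arguments show up as Placeholder ops, so the placeholder recipe doubles as a way to list the function's inputs.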
