【9.2】TensorFlow: tf.add_to_collection and get_collection

1. Introduction

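  • tf.add_to_collection: adds a value (such as a tensor or variable) to a named collection in the current (default) graph
  • tf.get_collection: retrieves all values stored in a collection, returned as a list
  • tf.add_n: adds all tensors in a list together element-wise and returns their sum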
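As a rough illustration of what tf.add_n computes, the sketch below models it in plain Python for 1-D inputs represented as lists of floats. The function name add_n_sketch is hypothetical and this is not TensorFlow's implementation; the real op works on tensors of any shape, but, like this toy version, it requires all inputs to share the same shape.

```python
from typing import List


def add_n_sketch(inputs: List[List[float]]) -> List[float]:
    """Hypothetical plain-Python model of tf.add_n for 1-D inputs (not TensorFlow code)."""
    if not inputs:
        raise ValueError("inputs must be a non-empty list")
    length = len(inputs[0])
    if any(len(x) != length for x in inputs):
        # tf.add_n likewise requires every input to have the same shape
        raise ValueError("all inputs must have the same shape")
    # Sum position by position across all inputs
    return [sum(values) for values in zip(*inputs)]


# Mirrors Example 1 below: the collection 'loss' holds [0.] and [2.]
print(add_n_sketch([[0.0], [2.0]]))                          # [2.0]
print(add_n_sketch([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))    # [9.0, 12.0]
```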

2. Parameters

Both functions are defined in the source file tensorflow/python/framework/ops.py.

2.1 tf.add_to_collection(name, value)

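  • name: the collection name (key). If no collection with this name exists yet, a new list is created.
  • value: the element to add to the collection.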

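The module-level tf.add_to_collection simply calls add_to_collection on the default graph (tf.get_default_graph()). The source of that Graph method is: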

def add_to_collection(self, name, value):
    """Stores `value` in the collection with the given `name`.
    Note that collections are not sets, so it is possible to add a value to
    a collection several times.
    Args:
      name: The key for the collection. The `GraphKeys` class
        contains many standard names for collections.
      value: The value to add to the collection.
    """
    self._check_not_finalized()
    with self._lock:
      if name not in self._collections:
        self._collections[name] = [value]
      else:
        self._collections[name].append(value)

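In other words, tf.add_to_collection stores value in the graph's collection dictionary (self._collections) under the key name. Because each collection is a list rather than a set, adding the same value twice stores it twice.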
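To make this bookkeeping concrete, here is a minimal plain-Python model of it. ToyGraph is a hypothetical class, not TensorFlow code; it assumes only what the source above shows: a dict from collection name to list, guarded by a lock, where add_to_collection creates the list on first use and appends afterwards.

```python
import threading


class ToyGraph:
    """Hypothetical stand-in for tf.Graph's collection bookkeeping (not TensorFlow code)."""

    def __init__(self):
        self._collections = {}          # collection name -> list of values
        self._lock = threading.Lock()   # the real Graph also guards this dict with a lock

    def add_to_collection(self, name, value):
        with self._lock:
            # Create the list on first use, then append; duplicates are kept
            self._collections.setdefault(name, []).append(value)

    def get_collection(self, name):
        with self._lock:
            # Return a copy, so callers cannot mutate the stored list
            return list(self._collections.get(name, []))


g = ToyGraph()
g.add_to_collection("loss", "v1")
g.add_to_collection("loss", "v2")
g.add_to_collection("loss", "v1")    # same value again: collections are lists, not sets
print(g.get_collection("loss"))      # ['v1', 'v2', 'v1']
print(g.get_collection("missing"))   # []
```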

2.2 tf.get_collection(name, scope=None)

This function returns the list of values stored in a collection.

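Parameters:

  • name: the collection name
  • scope: (optional) a string used to filter the returned items by their name attribute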

The source code:

def get_collection(self, name, scope=None):
    """Returns a list of values in the collection with the given `name`.
    This is different from `get_collection_ref()` which always returns the
    actual collection list if it exists in that it returns a new list each time
    it is called.
    Args:
      name: The key for the collection. For example, the `GraphKeys` class
        contains many standard names for collections.
      scope: (Optional.) A string. If supplied, the resulting list is filtered
        to include only items whose `name` attribute matches `scope` using
        `re.match`. Items without a `name` attribute are never returned if a
        scope is supplied. The choice of `re.match` means that a `scope` without
        special tokens filters by prefix.
    Returns:
      The list of values in the collection with the given `name`, or
      an empty list if no value has been added to that collection. The
      list contains the values in the order under which they were
      collected.
    """  # pylint: disable=g-doc-exception
    with self._lock:
      collection = self._collections.get(name, None)
      if collection is None:
        return []
      if scope is None:
        return list(collection)
      else:
        c = []
        regex = re.compile(scope)
        for item in collection:
          if hasattr(item, "name") and regex.match(item.name):
            c.append(item)
        return c

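So scope is a regular expression matched (with re.match) against the name attribute of each element in the collection, acting as a filter. Because re.match only anchors at the start of the string, a scope without special characters behaves as a prefix filter, and items without a name attribute are dropped whenever a scope is given. If nothing was ever added under name, an empty list is returned.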
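The filtering branch can be reproduced outside TensorFlow to see the prefix behaviour. get_collection_sketch below is a hypothetical helper that mirrors the loop in the source above, and types.SimpleNamespace objects stand in for variables carrying a name attribute. It shows one consequence of prefix matching: the scope "layer1" also matches "layer10/..."; appending the separator ("layer1/") avoids that.

```python
import re
from types import SimpleNamespace


def get_collection_sketch(collection, scope=None):
    """Hypothetical re-implementation of the scope filter in Graph.get_collection."""
    if scope is None:
        return list(collection)          # a fresh list on every call
    regex = re.compile(scope)
    # Keep only items that have a .name and whose name matches at the start
    return [item for item in collection
            if hasattr(item, "name") and regex.match(item.name)]


items = [
    SimpleNamespace(name="layer1/w:0"),
    SimpleNamespace(name="layer10/w:0"),
    SimpleNamespace(name="layer2/w:0"),
    "no-name-attribute",                 # a plain str has no .name, so scope filtering drops it
]

print([i.name for i in get_collection_sketch(items, "layer1")])
# ['layer1/w:0', 'layer10/w:0']  (prefix match also catches layer10)
print([i.name for i in get_collection_sketch(items, "layer1/")])
# ['layer1/w:0']
print(len(get_collection_sketch(items)))  # 4
```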

3. Examples

Example 1

Code:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable)

# Two one-element variables, both added to the collection 'loss'
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The collection is a plain Python list of the variable objects
    print(tf.get_collection('loss'))
    # tf.add_n sums the list element-wise: 0 + 2 = 2
    print(sess.run(tf.add_n(tf.get_collection('loss'))))

Output:

[<tensorflow.python.ops.variables.Variable object at 0x7f6b5d700c50>, <tensorflow.python.ops.variables.Variable object at 0x7f6b5d700c90>]
[ 2.]

Example 2

Code:

#!/usr/bin/python
# coding:utf-8

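import tensorflow as tf  # TensorFlow 1.x API
v1 = tf.get_variable('v1', shape=[3], initializer=tf.ones_initializer())
# Note: minval/maxval are passed in swapped order; samples still fall between -1 and 1
v2 = tf.get_variable('v2', shape=[5], initializer=tf.random_uniform_initializer(maxval=-1., minval=1., seed=0))
# Add both variables to the collection 'v' in the current graph
tf.add_to_collection('v', v1)
tf.add_to_collection('v', v2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Retrieve the values that were manually added to collection 'v'
    v = tf.get_collection('v')
    print(v)
    # eval() works here because the with-block makes sess the default session
    print(v[0].eval())
    print(v[1].eval())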

Output:

[<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>]
[ 1.  1.  1.]
[ 0.79827476 -0.9403336  -0.69752836  0.90343738  0.90295386]

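My personal WeChat official account (I'm a bit lazy and rarely update it). Feel free to ask questions there; if I don't reply in time, you can email me: tiehan@sina.cn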

About Sam
Focused on bioinformatics and translational medicine