【9.2】TensorFlow: tf.add_to_collection and tf.get_collection
1. Introduction
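- tf.add_to_collection: adds a value to a named collection in the current (default) graph
- tf.get_collection: retrieves all values stored in a collection, returned as a list
- tf.add_n: sums all tensors in a list element-wise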
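To make the semantics of tf.add_n concrete, here is a minimal pure-Python sketch that uses plain lists as stand-ins for 1-D tensors. `add_n_sketch` is a hypothetical helper for illustration only; the real tf.add_n takes tensors of identical shape and dtype and returns a tensor.

```python
def add_n_sketch(inputs):
    """Element-wise sum of equally sized 1-D lists (hypothetical stand-in for tf.add_n)."""
    if not inputs:
        raise ValueError("inputs must be a non-empty list")
    length = len(inputs[0])
    # tf.add_n also requires every input to have the same shape
    if any(len(x) != length for x in inputs):
        raise ValueError("all inputs must have the same shape")
    # zip(*inputs) groups the i-th element of every input together
    return [sum(column) for column in zip(*inputs)]

print(add_n_sketch([[0.0], [2.0]]))             # [2.0]
print(add_n_sketch([[1, 2, 3], [10, 20, 30]]))  # [11, 22, 33]
```

The first call mirrors Example 1 further down, where two one-element variables with values 0 and 2 sum to [2.].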
2. Parameters
Source code: tensorflow/python/framework/ops.py
2.1 tf.add_to_collection(name, value)
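- name: the name (key) of the collection. If no collection with this name exists, a new one is created.
- value: the value to add
tf.add_to_collection is a thin wrapper that calls the method of the same name on the default graph. The Graph method looks like this: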
def add_to_collection(self, name, value):
  """Stores `value` in the collection with the given `name`.

  Note that collections are not sets, so it is possible to add a value to
  a collection several times.

  Args:
    name: The key for the collection. The `GraphKeys` class
      contains many standard names for collections.
    value: The value to add to the collection.
  """
  self._check_not_finalized()
  with self._lock:
    if name not in self._collections:
      self._collections[name] = [value]
    else:
      self._collections[name].append(value)
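tf.add_to_collection stores value in the graph's collection dictionary (self._collections) under the key name: if the key does not exist yet, a new list [value] is created; otherwise value is appended to the existing list. Because collections are lists rather than sets, adding the same value twice stores it twice.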
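The bookkeeping above boils down to a dictionary that maps each collection name to a list. The following pure-Python sketch reproduces that behavior without TensorFlow. `MiniGraph` and its `finalize` method are hypothetical names chosen for illustration, not TensorFlow's Graph class, and `dict.setdefault` replaces the explicit if/else of the original with equivalent behavior.

```python
import threading

class MiniGraph:
    """Hypothetical stand-in for tf.Graph's collection bookkeeping (not TensorFlow code)."""

    def __init__(self):
        self._collections = {}        # collection name -> list of values
        self._lock = threading.Lock()
        self._finalized = False

    def finalize(self):
        # After this, the "graph" rejects modifications, like a finalized tf.Graph
        self._finalized = True

    def add_to_collection(self, name, value):
        # Mirrors self._check_not_finalized() in ops.py
        if self._finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        with self._lock:
            # setdefault creates the list on first use, then we append to it
            self._collections.setdefault(name, []).append(value)

g = MiniGraph()
g.add_to_collection("loss", "v1")
g.add_to_collection("loss", "v2")
g.add_to_collection("loss", "v1")  # collections are lists, not sets
print(g._collections)              # {'loss': ['v1', 'v2', 'v1']}
```

The lock only matters when several threads build the same graph; the duplicate "v1" entry shows why the docstring warns that collections are not sets.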
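2.2 tf.get_collection(name, scope=None)
This function returns the list of values stored in a collection.
Parameters:
- name: the name (key) of the collection
- scope: (optional) a string used to filter the returned items by their name attribute
Like tf.add_to_collection, tf.get_collection delegates to the default graph's method of the same name: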
def get_collection(self, name, scope=None):
  """Returns a list of values in the collection with the given `name`.

  This is different from `get_collection_ref()` which always returns the
  actual collection list if it exists in that it returns a new list each time
  it is called.

  Args:
    name: The key for the collection. For example, the `GraphKeys` class
      contains many standard names for collections.
    scope: (Optional.) A string. If supplied, the resulting list is filtered
      to include only items whose `name` attribute matches `scope` using
      `re.match`. Items without a `name` attribute are never returned if a
      scope is supplied. The choice of `re.match` means that a `scope` without
      special tokens filters by prefix.

  Returns:
    The list of values in the collection with the given `name`, or
    an empty list if no value has been added to that collection. The
    list contains the values in the order under which they were
    collected.
  """  # pylint: disable=g-doc-exception
  with self._lock:
    collection = self._collections.get(name, None)
    if collection is None:
      return []
    if scope is None:
      return list(collection)
    else:
      c = []
      regex = re.compile(scope)
      for item in collection:
        if hasattr(item, "name") and regex.match(item.name):
          c.append(item)
      return c
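As the code shows, scope is compiled into a regular expression and matched with re.match against the name of each item in the collection, so it acts as a filter: items whose name does not match, or that have no name attribute at all, are dropped. Because re.match only matches at the start of the string, a plain scope string without special characters effectively filters by name prefix. Without a scope, a copy of the whole collection is returned.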
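The scope filter can be tried out without TensorFlow. This sketch copies the filtering logic above into a standalone function; `get_collection_sketch` and the `Item` namedtuple (standing in for a variable that has a `name` attribute) are hypothetical names for illustration. Note that `mylayer1/w:0` is excluded because re.match only matches at the start of the name, and that the bare float without a name attribute is dropped once a scope is given.

```python
import re
from collections import namedtuple

# Hypothetical stand-in for a variable: only the `name` attribute matters here
Item = namedtuple("Item", ["name"])

def get_collection_sketch(collections, name, scope=None):
    """Mimics the filtering logic of Graph.get_collection shown above."""
    collection = collections.get(name)
    if collection is None:
        return []
    if scope is None:
        return list(collection)  # a new list each call, not the stored one
    regex = re.compile(scope)
    # re.match anchors at the start of the name, so a plain string is a prefix filter
    return [item for item in collection
            if hasattr(item, "name") and regex.match(item.name)]

collections = {
    "weights": [Item("layer1/w:0"), Item("layer2/w:0"), Item("mylayer1/w:0"), 3.0],
}
print(get_collection_sketch(collections, "weights", scope="layer1"))
# [Item(name='layer1/w:0')]
print(len(get_collection_sketch(collections, "weights")))  # 4 (no scope: nothing filtered)
print(get_collection_sketch(collections, "missing"))       # []
```

If substring matching is wanted instead of a prefix match, a pattern such as ".*layer1" works, because ".*" lets the match begin anywhere in the name.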
3. Examples
Example 1
Code:
import tensorflow as tf  # TensorFlow 1.x API (under TF 2.x these live in tf.compat.v1)

# Create two variables and add both to the collection named 'loss'
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The collection is a plain Python list of the added variables
    print(tf.get_collection('loss'))
    # add_n sums the list element-wise: 0 + 2 = 2
    print(sess.run(tf.add_n(tf.get_collection('loss'))))
Output:
[<tensorflow.python.ops.variables.Variable object at 0x7f6b5d700c50>, <tensorflow.python.ops.variables.Variable object at 0x7f6b5d700c90>]
[ 2.]
Example 2
Code:
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf

v1 = tf.get_variable('v1', shape=[3], initializer=tf.ones_initializer())
# Note: minval (1.) is greater than maxval (-1.) here; the sampled values still lie between -1 and 1
v2 = tf.get_variable('v2', shape=[5], initializer=tf.random_uniform_initializer(maxval=-1., minval=1., seed=0))

# Add both variables to a collection named 'v' in the current graph
tf.add_to_collection('v', v1)
tf.add_to_collection('v', v2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Retrieve the collection that was populated manually above
    v = tf.get_collection('v')
    print(v)
    print(v[0].eval())
    print(v[1].eval())
Output:
[<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>]
[ 1. 1. 1.]
[ 0.79827476 -0.9403336 -0.69752836 0.90343738 0.90295386]
