大家好,欢迎来到IT知识分享网。
mxnet symbol类定义:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/symbol.py
对于一个symbol,可分为non-grouped和grouped。且symbol具有输出,和输出属性。比如,对于Variable而言,其输入和输出就是它自己。对于c = a+b,c的内部有个_plus0 symbol,对于_plus0这个symbol,它的输入是a,b,输出是_plus0_output。
class Symbol(SymbolBase):
"""Symbol is symbolic graph of the mxnet."""
# disable dictionary storage, also do not have parent type.
# pylint: disable=no-member
其中,Symbol还不是最基础的类,Symbol类继承了SymbolBase这个类。
而SymbolBase这个类实际是在
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/symbol/_internal.py
中引用的,通过以下方式引用:
from .._ctypes.symbol import SymbolBase, _set_symbol_class, _set_np_symbol_class
而SymbolBase的定义是在:https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/_ctypes/symbol.py
这里暂时先不管SymbolBase,这应该是是python调用c++接口创建的一个类。
回到Symbol中来,对于mxnet符号式编程而言,定义的任何网络,或者变量,都是symbol类型,所以,了解这个类就显得很重要。
Symbol类中有几类函数:
1、普通函数
2、__xx__ 函数
3、@property 修饰的函数
4、函数名为xx,实际调用op.xx的函数
1、普通函数
attr
根据key返回symbol对应的属性字符串,只对non-grouped symbols起作用。
def attr(self, key):
"""Returns the attribute string for corresponding input key from the symbol.
list_attr
得到symbol的所有属性
def list_attr(self, recursive=False):
"""Gets all attributes from the symbol.
attr_dict
递归的得到symbol和孩子的属性
def attr_dict(self):
"""Recursively gets all attributes from the symbol and its children.
Example
-------
>>> a = mx.sym.Variable('a', attr={
'a1':'a2'})
>>> b = mx.sym.Variable('b', attr={
'b1':'b2'})
>>> c = a+b
>>> c.attr_dict()
{
'a': {
'a1': 'a2'}, 'b': {
'b1': 'b2'}}
_set_attr
通过key-value方式,对attr进行设置
def _set_attr(self, **kwargs):
"""Sets an attribute of the symbol.
For example. A._set_attr(foo="bar") adds the mapping ``"{foo: bar}"``
to the symbol's attribute dictionary.
get_internals
获取symbol的所有内部节点symbol,是一个group类型(包括输入,输出节点symbol)。如果我们想阶段一个network,应该获取它某内部节点的输出,这样才能作为新增加的symbol的输入。
def get_internals(self):
"""Gets a new grouped symbol `sgroup`. The output of `sgroup` is a list of
outputs of all of the internal nodes.
get_children
获取当前symbol输出节点的inputs
def get_children(self):
"""Gets a new grouped symbol whose output contains
inputs to output nodes of the original symbol.
list_arguments
列出当前symbol的所有参数(可以配合call对symbol进行改造)
def list_arguments(self):
"""Lists all the arguments in the symbol.
list_outputs
列出当前smybol的所有输出,如果当前symbol是grouped类型,回遍历输出每一个symbol的输出
def list_outputs(self):
"""Lists all the outputs in the symbol.
list_auxiliary_states
列出symbol中的辅助状态参数,比如BN
def list_auxiliary_states(self):
"""Lists all the auxiliary states in the symbol.
Example
-------
>>> a = mx.sym.var('a')
>>> b = mx.sym.var('b')
>>> c = a + b
>>> c.list_auxiliary_states()
[]
Example of auxiliary states in `BatchNorm`.
list_inputs
列出当前symbol的所有输入参数,和辅助状态,等价于 list_arguments和 list_auxiliary_states
def list_inputs(self):
"""Lists all arguments and auxiliary states of this Symbol.
2、__xx__函数
__repr__
对于gruop symbol,它是没有name属性的,print或者回车,结果就是其内部symbol节点的name
__iter__(self):
普通的symbol长度都只有1,只有Grouped 的symbol,长度才大于1:return (self[i] for i in range(len(self)))
算数及逻辑运算:
+,-,*, /,%,abs,**, 取负(-x),==,!=,>,>=,<,<=, # 使用时,要注意Broadcasting 是否支持
def __abs__(self):
"""x.__abs__() <=> abs(x) <=> x.abs() <=> mx.symbol.abs(x, y)"""
return self.abs()
def __add__(self, other):
"""x.__add__(y) <=> x+y
其他
__copy__和__deep_copy__
通过deep_copy,创建一个深拷贝,返回输入对象的一个拷贝,包括它当前所有参数的当前状态,比如weight,bias等
__call__
表示symbol的实例是一个可调用对象。可以返回一个新的symbol,这个symbol继承了之前symbol的权重啥的,但是和之前的symbol是不同的对象,可以输入参数对symbol进行组合。
def __call__(self, *args, **kwargs):
"""Composes symbol using inputs. Returns ------- The resulting symbol. """
s = self.__copy__() # 这里对symbol实例做了一次深拷贝,返回的新的symbol
s._compose(*args, **kwargs) # 实际调用的_compose函数
return s
# 对当前的symbol进行编译,返回一个新的symbol,可以指定新symbol的name,其他输入参数必须是symbol类型
# 当前symbol的输入参数,可以通过 .list_arguments()获取
def _compose(self, *args, **kwargs):
"""Composes symbol using inputs. x._compose(y, z) <=> x(y,z) This function mutates the current symbol. Example ------- Returns ------- The resulting symbol. """
name = kwargs.pop('name', None)
if name:
name = c_str(name)
if len(args) != 0 and len(kwargs) != 0:
raise TypeError('compose only accept input Symbols \
either as positional or keyword arguments, not both')
这里,我改变了b,将其输入参数的x的值变为了tt。
__getitem__
如果symbol的长度只有1,那么返回的就是它的输出symbol,如果symbol长度>1,可以通过切片访问其输出symbol,返回的也是一个Group symbol。symbol可以分为non-grouped和grouped。
获取内部节点symbol还可以输入str,但输入的str必须属于list_outputs(),
def __getitem__(self, index):
"""x.__getitem__(i) <=> x[i] Returns a sliced view of the input symbol. Parameters ---------- index : int or str Indexing key """
output_count = len(self)
if isinstance(index, py_slice):
# 输入切片
if isinstance(index, string_types):
# 输入字符串
# Returning this list of names is expensive. Some symbols may have hundreds of outputs
output_names = self.list_outputs()
idx = None
for i, name in enumerate(output_names):
if name == index:
if idx is not None:
raise ValueError('There are multiple outputs with name \"%s\"' % index)
idx = i
if idx is None:
raise ValueError('Cannot find output that matches name \"%s\"' % index)
index = idx
symbol.py 除了Symbol这个类之外,还有游离在外的函数:
1、
def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
init=None, stype=None, **kwargs):
"""Creates a symbolic variable with specified name. # for back compatibility Variable = var # 调用 mx.sym.var和mx.sym.Variable 等价 2、 def Group(symbols, create_fn=Symbol): """Creates a symbol that contains a collection of other symbols, grouped together.
A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the list
are of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all the
symbols in the list are of that type. A type error will be raised if a list of mixed
classic and numpy symbols are provided.
Example
-------
>>> a = mx.sym.Variable('a')
>>> b = mx.sym.Variable('b')
>>> mx.sym.Group([a,b])
<Symbol Grouped>
Parameters
----------
symbols : list
List of symbols to be grouped.
3、
def load(fname):
"""Loads symbol from a JSON file. You also get the benefit being able to directly load/save from cloud storage(S3, HDFS). Returns ------- sym : Symbol The loaded symbol. See Also -------- Symbol.save : Used to save symbol into file. # 输入文件可以是hdfs文件 4、 数学相关函数,输入可为scalar或者是symbol def pow(base, exp): """Returns element-wise result of base element raised to powers from exp element.
base 和 exp可以是数字或者symbol
# def power(base, exp): # 实际调用pow
def maximum(left, right):
def minimum(left, right):
def hypot(left, right): # 返回直角三角形的斜边
def eye(N, M=0, k=0, dtype=None, **kwargs):
"""Returns a new symbol of 2-D shpae, filled with ones on the diagonal and zeros elsewhere. # 返回2D shape的symbol,对角线为1,其余位置为0 def zeros(shape, dtype=None, **kwargs): """Returns a new symbol of given shape and type, filled with zeros. # 返回一个shape的全0 symbol
def ones(shape, dtype=None, **kwargs):
"""Returns a new symbol of given shape and type, filled with ones. def full(shape, val, dtype=None, **kwargs): """Returns a new array of given shape and type, filled with the given value `val`.
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval. def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None): """Returns evenly spaced values within a given interval.
def linspace(start, stop, num, endpoint=True, name=None, dtype=None):
"""Return evenly spaced numbers within a specified interval. def histogram(a, bins=10, range=None, **kwargs): """Compute the histogram of the input data.
def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
"""Split an array into multiple sub-arrays.
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/14314.html