mxnet series, tools: viewing the contents of a .params file

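Caffe already has code for inspecting the contents of a model.

Here is one for mxnet as well. It reads a checkpoint file named <prefix>-<epoch>.params (the epoch zero-padded to four digits) and splits its entries into arg_params (the network's weights) and aux_params (auxiliary states, such as BatchNorm moving statistics):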
```python
import mxnet as mx


def load_checkpoint(prefix, epoch):
    """
    Load model parameters from a checkpoint file.
    :param prefix: Prefix of the model name.
    :param epoch: Epoch number of the checkpoint to load.
    :return: (arg_params, aux_params)
    arg_params : dict of str to NDArray
        Model parameters: maps each name to an NDArray of the net's weights.
    aux_params : dict of str to NDArray
        Model parameters: maps each name to an NDArray of the net's auxiliary states.
    """
    # e.g. prefix='my_', epoch=1 -> 'my_-0001.params'
    save_dict = mx.nd.load('%s-%04d.params' % (prefix, epoch))
    arg_params = {}
    aux_params = {}
    # Keys are stored as 'arg:<name>' or 'aux:<name>'; split off the type tag.
    for k, v in save_dict.items():
        tp, name = k.split(':', 1)
        if tp == 'arg':
            arg_params[name] = v
        elif tp == 'aux':
            aux_params[name] = v
    return arg_params, aux_params


def convert_context(params, ctx):
    """
    :param params: dict of str to NDArray
    :param ctx: the context to copy the arrays to
    :return: dict of str to NDArray, with every array on context ctx
    """
    new_params = dict()
    for k, v in params.items():
        new_params[k] = v.as_in_context(ctx)
    return new_params


def load_param(prefix, epoch, convert=False, ctx=None):
    """
    Wrapper around load_checkpoint.
    :param prefix: Prefix of the model name.
    :param epoch: Epoch number of the checkpoint to load.
    :param convert: if True, copy the loaded arrays to ctx (e.g. a GPU) first.
    :param ctx: target context when convert is True; defaults to mx.cpu().
    :return: (arg_params, aux_params)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)
    if convert:
        if ctx is None:
            ctx = mx.cpu()
        arg_params = convert_context(arg_params, ctx)
        aux_params = convert_context(aux_params, ctx)
    return arg_params, aux_params


if __name__ == '__main__':
    # Reads my_-0001.params from the current directory.
    result = load_param('my_', 1)
    print('result is')
    print(result)
    print('one of results is:')
    # asnumpy() copies the NDArray into a numpy array for printing.
    print(result[0]['fc2_weight'].asnumpy())
```
Reposted from andeyeluguo
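Printing whole arrays, as the script above does, gets unwieldy for a large model. A more compact view is one line per parameter with its shape and dtype, plus a total count. The following is a minimal sketch under a few assumptions: split_param_keys and summarize are hypothetical helper names introduced here, and the code relies only on each value exposing .shape and .dtype (true of mxnet NDArray and numpy arrays). Unlike load_checkpoint above, split_param_keys does not raise on a key without an arg:/aux: tag; it collects such keys under "other". The demo at the bottom uses a namedtuple stand-in, so it runs without mxnet installed.

```python
from collections import namedtuple


def split_param_keys(save_dict):
    """Group a dict loaded from a .params file by its 'arg:'/'aux:' tag.

    Keys with any other form (no colon, or an unknown tag) are kept
    unchanged under 'other' instead of raising an error.
    """
    groups = {"arg": {}, "aux": {}, "other": {}}
    for key, value in save_dict.items():
        tag, sep, name = key.partition(":")
        if sep and tag in ("arg", "aux"):
            groups[tag][name] = value
        else:
            groups["other"][key] = value
    return groups


def summarize(params):
    """Return one 'name shape dtype' line per parameter, sorted by name,
    followed by the total number of scalar values."""
    lines = []
    total = 0
    for name in sorted(params):
        value = params[name]
        shape = tuple(value.shape)
        count = 1
        for dim in shape:
            count *= dim
        total += count
        # numpy scalar types (e.g. numpy.float32) have __name__; other
        # dtype representations fall back to str().
        dtype = getattr(value.dtype, "__name__", str(value.dtype))
        lines.append(f"{name:<30} {str(shape):<20} {dtype}")
    lines.append(f"total values: {total}")
    return "\n".join(lines)


if __name__ == "__main__":
    # Hypothetical stand-in for NDArray: only .shape and .dtype are used.
    FakeArray = namedtuple("FakeArray", ["shape", "dtype"])
    fake = {
        "arg:fc2_weight": FakeArray((10, 128), "float32"),
        "arg:fc2_bias": FakeArray((10,), "float32"),
        "aux:bn0_moving_mean": FakeArray((64,), "float32"),
    }
    groups = split_param_keys(fake)
    print(summarize(groups["arg"]))
    print(summarize(groups["aux"]))
```

With mxnet installed, the same helpers could be pointed at a real checkpoint, for example groups = split_param_keys(mx.nd.load('my_-0001.params')) followed by print(summarize(groups['arg'])).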
