samskalicky edited a comment on issue #19535:
URL:
https://github.com/apache/incubator-mxnet/issues/19535#issuecomment-727269397
Here's a complete, working solution for the v1.8.x branch. Note that the new `save`
and `load` functions perform a full export/reload of both the model
architecture and its parameters.
```python
import mxnet as mx
import json
class MyBlock(mx.gluon.nn.Block):
    """Sequential container that can save and restore its children's cached graphs.

    ``save(prefix)`` writes ``<prefix>-model.json`` (the block tree with each
    HybridBlock's cached graph and in/out formats, plus original block/child
    names) and ``<prefix>-model.params``. ``load(prefix)`` restores both into
    a freshly built network with the same structure (e.g. from ``createNet``).
    """

    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)

    def add(self, block):
        """Append ``block`` as a child; it runs after all previously added ones."""
        # Key by block name plus insertion position so keys stay distinct.
        self._children[block.name + str(len(self._children))] = block

    def forward(self, x, *args):
        """Feed ``x`` (and ``args``) through each child in insertion order."""
        out = (x,) + args
        for block in self._children.values():
            # A child may return a single array. Wrap it so it is passed as one
            # argument instead of being unpacked along its first axis.
            if not isinstance(out, (tuple, list)):
                out = (out,)
            out = block(*out)
        return out

    def save(self, prefix):
        """Write the block tree to ``<prefix>-model.json`` and params to
        ``<prefix>-model.params``.

        Every HybridBlock in the tree must have been hybridized and run at
        least once, so that its ``_cached_graph`` is populated.

        Raises:
            ValueError: if a HybridBlock has no cached graph yet.
        """
        model = {}

        def _save_cached_graphs(blk, index, structure):
            # Entries are keyed by lowercase type name plus a depth-first
            # counter, so load() can find them by walking an identically
            # built tree in the same order.
            mdl = {'orig_name': blk.name}
            structure[type(blk).__name__.lower() + str(index[0])] = mdl
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                if not blk._cached_graph:
                    raise ValueError(
                        'HybridBlock %s has no cached graph; hybridize() it and '
                        'run a forward pass before saving' % blk.name)
                syms, out = blk._cached_graph
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
            # Map each child's own name to the key this block stores it under,
            # so load() can restore the original keys.
            children = {}
            mdl['children'] = children
            for ch_name, child in blk._children.items():
                index[0] += 1
                children[child.name] = ch_name
                _save_cached_graphs(child, index, mdl)

        _save_cached_graphs(self, [0], model)
        with open(prefix + '-model.json', 'w') as fp:
            json.dump(model, fp)
        self.save_parameters(prefix + '-model.params')

    def load(self, prefix):
        """Restore a model written by ``save(prefix)`` into this block.

        This block must have the same structure as the saved one. Block names,
        child keys, cached graphs and in/out formats are restored first. Then
        parameters are loaded from ``<prefix>-model.params``.
        """
        with open(prefix + '-model.json') as fp:
            model = json.load(fp)

        def _load_cached_graphs(blk, index, log):
            # Same type-name + depth-first counter key that save() produced.
            mdl = log[type(blk).__name__.lower() + str(index[0])]
            # Rename the block to what it was called when saved.
            blk._name = mdl['orig_name']
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                out = mx.sym.load_json(mdl['symbol'])
                syms = [mx.sym.load_json(inp) for inp in mdl['inputs']]
                # Install the saved graph and mark the block as hybridized.
                blk._cached_graph = (syms, out)
                blk._active = True
            # Re-key this block's params under the restored block name.
            prefix_len = len(blk.params._prefix)
            for p in list(blk.params.keys()):
                param = blk.params._params.pop(p)
                blk.params._params[blk.name + '_' + p[prefix_len:]] = param
            for child in blk._children.values():
                index[0] += 1
                _load_cached_graphs(child, index, mdl)
            # Children now carry their original names (restored recursively).
            # Rebuild the child map under the original keys in one pass, so a
            # new key can never clobber a sibling that has not been moved yet.
            orig_keys = mdl['children']
            remapped = [(orig_keys[child.name], child)
                        for child in blk._children.values()]
            blk._children.clear()
            for key, child in remapped:
                blk._children[key] = child

        _load_cached_graphs(self, [0], model)
        self.load_parameters(prefix + '-model.params')
def createNet():
    """Build the demo network MyBlock[MyBlock[Dense(10)], Dense(10)].

    Blocks are constructed in the same order as before (inner container,
    its Dense, outer container, trailing Dense), so auto-assigned names match.
    """
    inner = MyBlock()
    inner.add(mx.gluon.nn.Dense(10))
    outer = MyBlock()
    for layer in (inner, mx.gluon.nn.Dense(10)):
        outer.add(layer)
    return outer
# Create and initialize the model.
net = createNet()
net.initialize()
# Use random input. mx.nd.empty returns uninitialized memory, which may hold
# NaN/garbage and would make the round-trip check below meaningless.
x = mx.nd.random.uniform(shape=(1, 10))
# First (imperative) run, to test the model.
out = net(x)
# Hybridize the hybridizable blocks (the Dense layers) and run again so their
# cached graphs exist before saving.
net.hybridize()
out = net(x)
# Save the hybridized model.
net.save('MyModel')
# Create a new, uninitialized model with the same structure and reload into it.
net = createNet()
net.load('MyModel')
# Run inference again; the reloaded model must reproduce the saved output.
reloaded = net(x)
max_diff = abs(reloaded.asnumpy() - out.asnumpy()).max()
assert max_diff < 1e-6, 'reloaded model output differs by %s' % max_diff
```
And here's a complete, working solution for the master branch:
```python
import mxnet as mx
import json
class MyBlock(mx.gluon.Block):
    """Sequential container that can save and restore its children's cached graphs.

    ``save(prefix)`` writes ``<prefix>-model.json`` (per-block cached graph,
    in/out formats and parameter uuids) plus ``<prefix>-model.params``.
    ``load(prefix)`` restores both into a freshly built network with the same
    structure (e.g. from ``createNet``).
    """

    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
        # Strong references to the children. Entries of self._children are
        # dereferenced with ``ref()`` below (they appear to be weak references
        # on this branch), so this list keeps the children alive.
        self.layers = []

    def add(self, block):
        """Append ``block``; it runs after all previously added blocks."""
        self.layers.append(block)
        self.register_child(block)

    def forward(self, x, *args):
        """Feed ``x`` (and ``args``) through each child in insertion order."""
        out = (x,) + args
        for ref in self._children.values():
            # A child may return a single array. Wrap it so it is passed as one
            # argument instead of being unpacked along its first axis.
            if not isinstance(out, (tuple, list)):
                out = (out,)
            out = ref()(*out)
        return out

    def save(self, prefix):
        """Write the block tree to ``<prefix>-model.json`` and params to
        ``<prefix>-model.params``.

        Every HybridBlock in the tree must have been hybridized and run at
        least once, so that its ``_cached_graph`` is populated.

        Raises:
            ValueError: if a HybridBlock has no cached graph yet.
        """
        model = {}

        def _save_cached_graphs(blk, index, structure):
            # Entries are keyed by lowercase type name plus a depth-first
            # counter, so load() can find them by walking an identically
            # built tree in the same order.
            mdl = {}
            structure[type(blk).__name__.lower() + str(index[0])] = mdl
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                if not blk._cached_graph:
                    raise ValueError(
                        'HybridBlock %s has no cached graph; hybridize() it and '
                        'run a forward pass before saving' % type(blk).__name__)
                syms, out = blk._cached_graph
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
            # load() writes these uuids back onto the new model's parameters
            # (presumably so the saved graph matches them -- verify).
            mdl['params'] = {p: blk.params[p]._uuid for p in list(blk.params.keys())}
            for ref in blk._children.values():
                index[0] += 1
                _save_cached_graphs(ref(), index, mdl)

        _save_cached_graphs(self, [0], model)
        with open(prefix + '-model.json', 'w') as fp:
            json.dump(model, fp)
        # Use the caller's prefix (previously hard-coded to 'MyModel').
        self.save_parameters(prefix + '-model.params')

    def load(self, prefix):
        """Restore a model written by ``save(prefix)`` into this block.

        This block must have the same structure as the saved one. Cached
        graphs, in/out formats and parameter uuids are restored first. Then
        parameters are loaded from ``<prefix>-model.params``.
        """
        with open(prefix + '-model.json') as fp:
            model = json.load(fp)

        def _load_cached_graphs(blk, index, structure):
            # Same type-name + depth-first counter key that save() produced.
            mdl = structure[type(blk).__name__.lower() + str(index[0])]
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                out = mx.sym.load_json(mdl['symbol'])
                syms = [mx.sym.load_json(inp) for inp in mdl['inputs']]
                # Install the saved graph and mark the block as hybridized.
                blk._cached_graph = (syms, out)
                blk._active = True
            # .get() also accepts files written without a 'params' entry for
            # non-hybrid blocks.
            for p, uuid in mdl.get('params', {}).items():
                blk.params[p]._uuid = uuid
            for ref in blk._children.values():
                index[0] += 1
                _load_cached_graphs(ref(), index, mdl)

        _load_cached_graphs(self, [0], model)
        # Use the caller's prefix (previously hard-coded to 'MyModel').
        self.load_parameters(prefix + '-model.params')
def createNet():
    """Build the demo network MyBlock[MyBlock[Dense(10)], Dense(10)].

    Blocks are constructed in the same order as before (inner container,
    its Dense, outer container, trailing Dense).
    """
    inner = MyBlock()
    inner.add(mx.gluon.nn.Dense(10))
    outer = MyBlock()
    for layer in (inner, mx.gluon.nn.Dense(10)):
        outer.add(layer)
    return outer
# Create and initialize the model.
net = createNet()
net.initialize()
# Run a first (imperative) inference to test the model.
x = mx.nd.random.randn(1, 10)
out = net(x)
# Hybridize the hybridizable blocks (the Dense layers) and run again so their
# cached graphs exist before saving.
net.hybridize()
out = net(x)
# Save the hybridized model.
net.save('MyModel')
# Create a new, uninitialized model with the same structure and reload into it.
net = createNet()
net.load('MyModel')
# Run inference again; the reloaded model must reproduce the saved output.
reloaded = net(x)
max_diff = abs(reloaded.asnumpy() - out.asnumpy()).max()
assert max_diff < 1e-6, 'reloaded model output differs by %s' % max_diff
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]