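Here's a complete, working solution for the v1.8.x branch. Note that the new `save` and
`load` methods export and reload the full model: both the architecture (the cached
graphs of the hybridized blocks) and the parameters. To reload, first rebuild the
network with the same Python code (here `createNet()`), then call `load` on it.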
```
import mxnet as mx
import json
class MyBlock(mx.gluon.nn.Block):
    """Sequential container whose hybridized children can be saved and reloaded.

    Children run in insertion order; each child's output becomes the input of
    the next. ``save``/``load`` persist the cached graphs of hybridized blocks
    together with the parameters, so a freshly built network of the same
    structure can be restored.

    NOTE(review): relies on private MXNet 1.x Gluon internals (``_children``,
    ``_cached_graph``, ``_cached_op``, ``_name``, ``params._params``); confirm
    before using with other MXNet versions.
    """

    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)

    def add(self, block):
        """Append *block* as a child, keyed by its name plus its position."""
        self.register_child(block, block.name + str(len(self._children)))

    def forward(self, x, *args):
        """Feed the inputs through every child in order; return the last output.

        With no children, the inputs are returned as a tuple ``(x, *args)``.
        """
        out = (x,) + args
        for block in self._children.values():
            # Only unpack real multi-output sequences: star-unpacking a single
            # NDArray would split it along its first (batch) axis.
            inputs = out if isinstance(out, (tuple, list)) else (out,)
            out = block(*inputs)
        return out

    def save(self, prefix):
        """Save the block tree's cached graphs and its parameters.

        Writes ``<prefix>-model.json`` (original block names, child-key maps
        and, for HybridBlocks that have a cached graph, their input/output
        symbols) and ``<prefix>-model.params`` via ``save_parameters``.

        :param prefix: path prefix for both output files.
        """
        model = {}

        def _save_cached_graphs(blk, index, structure):
            # Entries are keyed by lowercase class name + depth-first visit
            # counter so load() can find them by walking an identically built
            # tree in the same order.
            mdl = {'orig_name': blk.name}
            structure[type(blk).__name__.lower() + str(index[0])] = mdl
            # A HybridBlock only has a cached graph once it has been hybridized
            # and run on its own; nested ones inside a hybridized parent have
            # none, so skip the graph instead of failing to unpack it.
            if isinstance(blk, mx.gluon.nn.HybridBlock) and blk._cached_graph:
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                syms, out = blk._cached_graph
                mdl['inputs'] = [sym.tojson() for sym in syms]
                mdl['symbol'] = out.tojson()
            children = {}
            mdl['children'] = children
            for ch_name, child in blk._children.items():
                index[0] += 1
                # Remember which key this block used for the child.
                children[child.name] = ch_name
                _save_cached_graphs(child, index, mdl)

        _save_cached_graphs(self, [0], model)
        with open(prefix + '-model.json', 'w') as fp:
            json.dump(model, fp)
        self.save_parameters(prefix + '-model.params')

    def load(self, prefix):
        """Restore cached graphs, names and parameters written by ``save``.

        The receiving block must have the same structure (same block types
        in the same order) as the saved one, because saved entries are found
        by class name and depth-first visit order.

        :param prefix: path prefix used when saving.
        """
        with open(prefix + '-model.json') as fp:
            model = json.load(fp)

        def _load_cached_graphs(blk, index, log):
            mdl = log[type(blk).__name__.lower() + str(index[0])]
            # Rename the block to what it was when saved.
            blk._name = mdl['orig_name']
            if isinstance(blk, mx.gluon.nn.HybridBlock) and 'symbol' in mdl:
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                out = mx.sym.load_json(mdl['symbol'])
                syms = [mx.sym.load_json(inp) for inp in mdl['inputs']]
                blk._cached_graph = (syms, out)
                # Drop any op built from a previous graph so it is rebuilt.
                blk._cached_op = None
                blk._active = True
            # Re-key this block's params under its restored name; presumably
            # so the restored symbols' input names resolve against
            # collect_params(). Rebuilt in one pass so a new key can never
            # overwrite an entry that has not been renamed yet.
            params = blk.params._params
            prefix_len = len(blk.params._prefix)
            renamed_params = [(blk.name + '_' + p[prefix_len:], param)
                              for p, param in params.items()]
            params.clear()
            params.update(renamed_params)
            for child in blk._children.values():
                index[0] += 1
                _load_cached_graphs(child, index, mdl)
            # Children now carry their original names; map them back to the
            # keys used at save time (the same one-pass rebuild avoids key
            # collisions).
            orig_keys = mdl['children']
            renamed_children = [(orig_keys[child.name], child)
                                for child in blk._children.values()]
            blk._children.clear()
            blk._children.update(renamed_children)

        _load_cached_graphs(self, [0], model)
        self.load_parameters(prefix + '-model.params')
def createNet():
    """Build the demo network.

    Layout: ``MyBlock[ MyBlock[Dense(10)], Dense(10) ]``. Blocks are created
    in the same order as always, so their auto-generated names (which
    ``save``/``load`` record) stay the same.
    """
    nested = MyBlock()
    nested.add(mx.gluon.nn.Dense(10))
    top = MyBlock()
    top.add(nested)
    top.add(mx.gluon.nn.Dense(10))
    return top
# Create and initialize the model.
net = createNet()
net.initialize()
# Initialized input: mx.nd.empty would hand the network uninitialized memory.
x = mx.nd.ones((1, 10))
# First (imperative) inference.
out = net(x)
# Hybridize the hybridizable children (the Dense layers) and run again so
# their cached graphs are built.
net.hybridize()
out = net(x)
# Save the hybridized model.
net.save('MyModel')
# Rebuild the same structure, uninitialized, and restore it from disk.
net = createNet()
net.load('MyModel')
# Run inference again and verify the reloaded model reproduces the output.
reloaded = net(x)
max_diff = mx.nd.abs(reloaded - out).max().asscalar()
if max_diff > 1e-5:
    raise RuntimeError('reloaded model output differs by %g' % max_diff)
out = reloaded
```
--
You are receiving this because you are subscribed to this thread.
Reply to this email directly or view it on GitHub:
https://github.com/apache/incubator-mxnet/issues/19535#issuecomment-727269397