wangwei created SINGA-413:
-----------------------------

Summary: Hyper-parameter configuration API
Key: SINGA-413
URL: https://issues.apache.org/jira/browse/SINGA-413
Project: Singa
Issue Type: Improvement
Reporter: wangwei
The current API for hyper-parameter configuration in Rafiki requires the model contributor to implement a JSON-style knob configuration and then read the chosen values back in {{init}}:

{code:python}
def get_knob_config(self):
    return {
        'knobs': {
            'hidden_layer_units': {
                'type': 'int',
                'range': [2, 128]
            },
            # ...
        }
    }

def init(self, knobs):
    self.hidden_layer_units = knobs.get('hidden_layer_units')
{code}

The JSON style could be replaced by declaring each knob as a model attribute:

{code:python}
def check(knobs, val):
    # knobs: the knobs listed in `depends`; val: the candidate value for this knob.
    # Disallow num_epoch == 4 when batch_size == 4.
    return not (knobs[0].value() == 4 and val == 4)

def register_knobs(self):
    self.hidden_layer_units = Knob(32, kInt, range=(2, 128))
    self.learning_rate = Knob(0.02, kExpFloat, range=(1e-4, 1e-1))
    self.batch_size = Knob(8, kIntCat, range=[4, 8, 16])
    # Pass the function itself as the callback, not the result of calling it.
    self.num_epoch = Knob(4, kIntCat, range=[4, 8, 16],
                          depends=self.batch_size, callback=check)

def train(self):
    for i in range(self.num_epoch.value()):
        ...
{code}

The worker's workflow would then look like this:

{code:python}
for trial in range(total_trials):
    model = Model()
    model.register_knobs()
    # Collect every Knob attribute created by register_knobs()
    knobs = [knob for name, knob in model.__dict__.items() if isinstance(knob, Knob)]
    if trial == 0:
        register_knobs_to_advisor(knobs)
    # Ask the advisor to assign values to this trial's knobs
    get_knobs_from_advisor(knobs)
{code}
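A minimal, runnable sketch of how the proposed pieces could fit together follows. Everything in it is an assumption for illustration, not Rafiki's API: the `Knob` class, the string values behind `kInt`/`kExpFloat`/`kIntCat`, and the random-search `RandomAdvisor` (standing in for `register_knobs_to_advisor`/`get_knobs_from_advisor`) are all hypothetical. The callback is assumed to receive the list of knobs given in `depends` plus a candidate value, which matches how `check` indexes `knobs[0]`.

```python
import math
import random

# Hypothetical type tags; the proposal names kInt, kExpFloat and kIntCat but does not define them.
kInt, kExpFloat, kIntCat = "int", "exp_float", "int_cat"


class Knob:
    """Hypothetical tunable hyper-parameter holding a default/current value."""

    def __init__(self, default, knob_type, range, depends=None, callback=None):
        self.knob_type = knob_type
        self.range = range  # (low, high) for kInt/kExpFloat, list of choices for kIntCat
        # Normalise `depends` to a list so callbacks can index it as knobs[0].
        if isinstance(depends, Knob):
            depends = [depends]
        self.depends = list(depends or [])
        self.callback = callback  # callback(depends, candidate) -> True if allowed
        self._value = default

    def value(self):
        return self._value

    def set_value(self, val):
        self._value = val


def check(knobs, val):
    # Disallow num_epoch == 4 together with batch_size == 4.
    return not (knobs[0].value() == 4 and val == 4)


class Model:
    def register_knobs(self):
        self.hidden_layer_units = Knob(32, kInt, range=(2, 128))
        self.learning_rate = Knob(0.02, kExpFloat, range=(1e-4, 1e-1))
        self.batch_size = Knob(8, kIntCat, range=[4, 8, 16])
        self.num_epoch = Knob(4, kIntCat, range=[4, 8, 16],
                              depends=self.batch_size, callback=check)

    def train(self):
        epochs_run = 0
        for _ in range(self.num_epoch.value()):
            epochs_run += 1  # placeholder for one epoch of real training
        return epochs_run


class RandomAdvisor:
    """Hypothetical random-search advisor (not Rafiki's advisor)."""

    def __init__(self, seed=0, max_tries=100):
        self.rng = random.Random(seed)
        self.max_tries = max_tries
        self.names = []

    def register(self, knobs):
        # Keep declaration order so a knob's dependencies are assigned before it.
        self.names = list(knobs)

    def _sample(self, knob):
        if knob.knob_type == kInt:
            return self.rng.randint(knob.range[0], knob.range[1])
        if knob.knob_type == kExpFloat:
            # Uniform in log10 space, i.e. log-uniform between the two bounds.
            low, high = math.log10(knob.range[0]), math.log10(knob.range[1])
            return 10 ** self.rng.uniform(low, high)
        return self.rng.choice(knob.range)  # categorical knob

    def propose(self, knobs):
        for name in self.names:
            knob = knobs[name]
            # Rejection sampling: redraw until the knob's callback accepts the value.
            for _ in range(self.max_tries):
                val = self._sample(knob)
                if knob.callback is None or knob.callback(knob.depends, val):
                    knob.set_value(val)
                    break
            else:
                raise ValueError(f"no valid value found for knob {name!r}")


def collect_knobs(model):
    # vars() preserves attribute assignment order, i.e. the order in register_knobs().
    return {name: attr for name, attr in vars(model).items() if isinstance(attr, Knob)}


def run_worker(total_trials, advisor):
    history = []
    for trial in range(total_trials):
        model = Model()
        model.register_knobs()
        knobs = collect_knobs(model)
        if trial == 0:
            advisor.register(knobs)
        advisor.propose(knobs)
        values = {name: knob.value() for name, knob in knobs.items()}
        history.append((values, model.train()))
    return history
```

Calling `run_worker(20, RandomAdvisor(seed=1))` returns one entry per trial, pairing the sampled knob values with the number of epochs run. Rejection sampling is simply the easiest way to honour `callback` in a sketch; it assumes each knob's dependencies are declared before the knob itself, so they already hold this trial's values when the callback runs.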