Update: OpFromGraph didn't work well, but I managed to fix bug in
stochastic softmax.
What happened was that Theano tries to recompute the node depending on
other parts of the network, and because the node is stochastic, its forward
pass and backward pass became inconsistent.
So I made the op deterministic by taking uniform noise as explicit input to
the op. For using it in RNN and thus scan, I needed MRG_RandomStreams and
update dictionary returned by scan, but it is another story.
Thank you for helping me anyway! For future victims of similar problems, I
am posting my code for stochastic softmax.
import numpy as np
import theano
import theano.tensor as T
# Note: theano.tensor.shared_randomstreams.RandomStreams does not work in RNN.
# https://groups.google.com/d/msg/theano-users/DbvTgTqkT8o/raObyPqQX8YJ
# https://groups.google.com/d/msg/theano-users/FdPwm3517NY/SvznnPs83YEJ
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
class StochasticSoftmax(theano.Op):
    """
    Given logits and uniform noise, produce a one-hot vector whose hot
    element is sampled from the softmax distribution over the logits.

    The op is deterministic given its inputs: the uniform noise is an
    explicit input rather than drawn internally, so Theano's forward and
    backward passes stay consistent. The gradient w.r.t. the logits is
    approximated by the softmax gradient (straight-through estimator);
    no gradient is defined w.r.t. the noise.
    """

    # Required so that two instances of this op compare/hash equal;
    # the op has no parameters, hence the empty tuple.
    __props__ = ()

    def make_node(self, *inputs):
        """Build the Apply node; the output has the same type as the logits."""
        logits, noise = inputs
        logits = T.as_tensor_variable(logits)
        noise = T.as_tensor_variable(noise)
        return theano.Apply(self, [logits, noise], [logits.type()])

    def perform(self, node, inputs, output_storage):
        """Sample one-hot vectors from softmax(logits) on the last axis."""
        logits, noise = inputs
        # Gumbel-max trick: argmax(logits + Gumbel noise) is distributed
        # as a draw from softmax(logits); -log(-log(U)) is Gumbel(0, 1)
        # for U ~ Uniform(0, 1).
        sampled_indices = np.argmax(logits - np.log(-np.log(noise)), axis=-1)
        one_hot_vectors = np.eye(logits.shape[-1],
                                 dtype=np.float32)[sampled_indices]
        output_storage[0][0] = one_hot_vectors

    def grad(self, inputs, grads):
        """Straight-through estimator: reuse softmax's gradient for logits."""
        logits, noise = inputs
        grad, = grads
        # softmax_grad expects 2-D inputs, so collapse leading dimensions.
        logits_2d = logits.reshape((-1, logits.shape[-1]))
        grad_2d = grad.reshape((-1, logits.shape[-1]))
        softmax_output_2d = T.nnet.softmax(logits_2d)
        grad_wrt_logits_2d = T.nnet.softmax_grad(grad_2d, softmax_output_2d)
        grad_wrt_logits = grad_wrt_logits_2d.reshape(logits.shape)
        # Fix: the original pasted string literal was broken across lines
        # (a syntax error); use implicit string concatenation instead.
        error_comment = ('Gradient with respect to noise is not required for '
                         'backprop, so it is not implemented.')
        grad_wrt_noise = T.grad_not_implemented(self, 1, noise,
                                                comment=error_comment)
        return [grad_wrt_logits, grad_wrt_noise]
def stochastic_softmax(logits):
    """Sample a one-hot vector from softmax(logits).

    Draws uniform noise from an MRG random stream and feeds it to the
    deterministic StochasticSoftmax op, so the sampling is scan-friendly.
    """
    rng = RandomStreams()
    uniform_noise = rng.uniform(logits.shape)
    sampler = StochasticSoftmax()
    return sampler(logits, uniform_noise)
On Thursday, March 30, 2017 at 10:12:55 AM UTC+1, Yoh Okuno wrote:
>
> Thank you! I will try it.
>
> On Thursday, March 30, 2017 at 12:58:34 AM UTC+1, Pascal Lamblin wrote:
>>
>> You can try OpFromGraph with inline=True, and specify
>> override_gradients=... with the right expression.
>> This is still experimental.
>>
>> On Wed, Mar 29, 2017, nokun...@gmail.com wrote:
>> > Hi guys!
>> >
>> > I recently started using theano and am struggling to implement custom
>> > gradient for stochastic node. Can anyone help me?
>> >
>> > What I want is an op that produces one hot vector whose hot element is
>> > sampled from softmax distribution.
>> > The op is not differentiable, but I want to "fake" as if its gradient
>> is
>> > softmax's one ("straight through estimator").
>> > Below is the minimal code that performs the forward pass, which raises
>> > DisconnectedInputError due to the missing gradient.
>> >
>> >
>> > import theano
>> >
>> > import theano.tensor as T
>> >
>> > import numpy as np
>> >
>> >
>> > logits_values = np.random.uniform(-1, 1, size=3)
>> >
>> > logits = theano.shared(logits_values, 'logits')
>> >
>> > probabilities = T.nnet.softmax(logits)
>> >
>> > print('probabilities', probabilities.eval())
>> >
>> > # result: probabilities [[ 0.55155489  0.290773    0.15767211]]
>> >
>> >
>> > random_streams = T.shared_randomstreams.RandomStreams()
>> >
>> > index = random_streams.choice(size=(1,), a=3, p=probabilities[0])
>> >
>> > samples = T.extra_ops.to_one_hot(index, logits.shape[-1])
>> >
>> > print('samples', samples.eval())
>> >
>> > # result: samples [[ 1. 0. 0.]]
>> >
>> >
>> > # We want to use gradient of probabilities instead of samples!
>> >
>> > samples_grad = T.grad(samples[0][0], logits)
>> >
>> > # result: raise DisconnectedInputError
>> >
>> >
>> > The node