Hi guys!

I recently started using Theano and am struggling to implement a custom
gradient for a stochastic node. Can anyone help me?

What I want is an op that produces a one-hot vector whose hot element is
sampled from a softmax distribution.
The op is not differentiable, but I want to "fake" its gradient as if it
were the softmax's gradient (the "straight-through estimator").
Below is minimal code that performs the forward pass; the final T.grad call
raises DisconnectedInputError because the gradient is missing.

import theano
import theano.tensor as T
import numpy as np

logits_values = np.random.uniform(-1, 1, size=3)
logits = theano.shared(logits_values, 'logits')
probabilities = T.nnet.softmax(logits)
print('probabilities', probabilities.eval())
# result: probabilities [[ 0.55155489  0.290773    0.15767211]]

random_streams = T.shared_randomstreams.RandomStreams()
index = random_streams.choice(size=(1,), a=3, p=probabilities[0])
samples = T.extra_ops.to_one_hot(index, logits.shape[-1])
print('samples', samples.eval())
# result: samples [[ 1.  0.  0.]]

# We want the gradient of probabilities to be used instead of that of samples!
samples_grad = T.grad(samples[0][0], logits)
# result: raises DisconnectedInputError
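To pin down what "straight-through" means here, the following NumPy-only sketch (not Theano; the helper names st_forward and st_backward are hypothetical) separates the two halves: the forward pass draws a one-hot sample from softmax(logits), and the backward pass ignores the sampling and returns the softmax vector-Jacobian product (g - sum(g * p)) * p for an upstream gradient g and probabilities p = softmax(logits).

```python
import numpy as np


def softmax(logits):
    # Subtract the max for numerical stability; normalizes along the last axis.
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def st_forward(logits, rng):
    # Forward pass (hypothetical helper): a one-hot sample drawn from softmax(logits).
    probs = softmax(logits)
    index = rng.choice(logits.shape[-1], p=probs)
    return np.eye(logits.shape[-1])[index]


def st_backward(logits, upstream_grad):
    # Backward pass (hypothetical helper): pretend the op was softmax and
    # return the softmax vector-Jacobian product (g - sum(g * p)) * p.
    p = softmax(logits)
    return (upstream_grad - np.sum(upstream_grad * p, axis=-1, keepdims=True)) * p


rng = np.random.RandomState(0)
logits = rng.uniform(-1, 1, size=3)
sample = st_forward(logits, rng)

# Upstream gradient of samples[0][0]: a unit vector on the first component.
g = np.array([1.0, 0.0, 0.0])
grad_logits = st_backward(logits, g)
print('sample', sample)
print('straight-through grad', grad_logits)
```

With g = [1, 0, 0] this is the gradient of probabilities[0][0] with respect to logits, which is what the T.grad call above would ideally return for samples[0][0].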

The node is not the final layer, so I can't train it with a categorical
cross-entropy loss.

I am trying to implement a custom op (see the attached stochastic_softmax.py,
reproduced below), but it does not work in practice.

Since I already have a working expression for the forward pass, can I simply
override the gradient of that existing expression?
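One common way to override the gradient of an existing expression is the stop-gradient identity y = p + stop_gradient(s - p): its value is exactly the sample s, while its gradient is that of p. In Theano this could presumably be written along the lines of probabilities[0] + theano.gradient.disconnected_grad(samples[0] - probabilities[0]) (untested; check that your Theano version provides disconnected_grad). The NumPy sketch below only checks the identity numerically: it freezes s - p at the evaluation point to mimic the stop-gradient, and the downstream loss and its weights W are hypothetical.

```python
import numpy as np

# Hypothetical weights of a downstream loss applied to the node's output.
W = np.array([0.5, -2.0, 1.5])


def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()


def loss(y):
    # Hypothetical downstream loss: (W . y)^2
    return np.dot(W, y) ** 2


rng = np.random.RandomState(1)
logits0 = rng.uniform(-1, 1, size=3)
p0 = softmax(logits0)
s = np.eye(3)[rng.choice(3, p=p0)]  # one-hot sample, fixed once drawn

# stop_gradient(s - p) acts as a constant during differentiation,
# so freeze it at the evaluation point.
frozen = s - p0


def st_output(logits):
    # Value: softmax(logits) + (s - p0), which equals s at logits0.
    return softmax(logits) + frozen


y0 = st_output(logits0)

# Numerical gradient of loss(st_output(logits)) at logits0 (central differences).
eps = 1e-6
num_grad = np.array([
    (loss(st_output(logits0 + eps * d)) - loss(st_output(logits0 - eps * d))) / (2 * eps)
    for d in np.eye(3)
])

# Analytic straight-through gradient: softmax VJP of dloss/dy evaluated at y = s.
g = 2.0 * np.dot(W, s) * W
st_grad = (g - np.sum(g * p0)) * p0
print('num_grad', num_grad)
print('st_grad ', st_grad)
```

The forward value y0 matches the one-hot sample, yet the numerical gradient agrees with the softmax-based gradient, which is exactly the straight-through behaviour asked for.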


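import numpy as np
import theano
import theano.tensor as T

class StochasticSoftmax(theano.Op):
    """Forward: one-hot sample from softmax(x). Backward: gradient of softmax(x)."""

    def __init__(self, random_state=None):
        # Avoid a single default RandomState created once at definition time
        self.random_state = (random_state if random_state is not None
                             else np.random.RandomState())

    def make_node(self, x):
        x = T.as_tensor_variable(x)
        # The output has the same type (dtype and ndim) as the input
        return theano.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        # Gumbel-max trick: argmax(x + Gumbel noise) is a sample from softmax(x)
        x, = inputs
        z = self.random_state.gumbel(loc=0, scale=1, size=x.shape)
        indices = (x + z).argmax(axis=-1)
        # Use the input dtype so the output matches the type declared in make_node
        y = np.eye(x.shape[-1], dtype=x.dtype)[indices]
        output_storage[0][0] = y

    def grad(self, inp, grads):
        x, = inp
        g_y, = grads
        # Straight-through: backpropagate as if the output were softmax(x).
        # Softmax is computed by hand so that sm has the same ndim as x
        # (T.nnet.softmax returns a matrix even for a vector input).
        e = T.exp(x - x.max(axis=-1, keepdims=True))
        sm = e / e.sum(axis=-1, keepdims=True)
        return [(g_y - (g_y * sm).sum(axis=-1, keepdims=True)) * sm]

    def infer_shape(self, node, i0_shapes):
        return i0_shapes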
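The perform method relies on the Gumbel-max trick: if z is i.i.d. standard Gumbel noise, argmax(x + z) is distributed as softmax(x). A quick NumPy sanity check of that claim, repeating the computation from perform on a batch of tiled logits (the logits, seed, and sample count are arbitrary choices, and gumbel_max_one_hot is a hypothetical helper), might look like this:

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gumbel_max_one_hot(x, random_state):
    # Same computation as StochasticSoftmax.perform: argmax of logits + Gumbel noise.
    z = random_state.gumbel(loc=0, scale=1, size=x.shape)
    indices = (x + z).argmax(axis=-1)
    return np.eye(x.shape[-1], dtype=x.dtype)[indices]


rng = np.random.RandomState(42)
logits = np.array([0.5, -0.2, -0.8])
n_draws = 100000

# Tile the logits so one call draws n_draws independent one-hot samples.
batch = np.tile(logits, (n_draws, 1))
samples = gumbel_max_one_hot(batch, rng)

empirical = samples.mean(axis=0)
expected = softmax(logits)
print('empirical', empirical)
print('softmax  ', expected)
```

The empirical one-hot frequencies should agree with softmax(logits) up to sampling noise, so the forward pass of the op is sound and any remaining trouble is more likely on the gradient side.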

