// Somewhere in ann namespace...

namespace augmented {
    /**
     * Interface sketch for the NTM model (presumably "Neural Turing Machine" -
     * TODO confirm): a controller combined with an external memory and a
     * similarity measure.
     *
     * MatType is the last template parameter because a class template
     * parameter with a default argument may not be followed by parameters
     * without one; keeping it last lets callers omit it and get arma::mat.
     *
     * @tparam Controller Type of the controller model.
     * @tparam SimilarityCallable Type of the similarity callable handed to the
     *     constructor (presumably used for memory addressing - TODO confirm).
     * @tparam MatType Matrix type used for data and for the memory.
     */
    template<typename Controller,
             typename SimilarityCallable,
             typename MatType = arma::mat>
    class NTM {
    public:
        /**
         * Build the model from its three components.
         *
         * @param controller Controller instance (taken by value).
         * @param memory Initial memory contents (taken by value).
         * @param similarity Similarity callable (taken by value).
         */
        NTM(Controller controller, MatType memory, SimilarityCallable similarity);

        /**
         * Train the model on the given data.
         *
         * @tparam OptimizerType Optimizer class template; it is instantiated
         *     with Controller and normally deduced from the optimizer argument.
         * @param predictors Input data.
         * @param responses Target outputs for the predictors.
         * @param optimizer Optimizer instance (taken by value).
         */
        template<
            template<typename> class OptimizerType = mlpack::optimization::SGD
        >
        void Train(const MatType& predictors,
                   const MatType& responses,
                   OptimizerType<Controller> optimizer);

        /**
         * Compute the model output for the given predictors.
         *
         * @param predictors Input data.
         * @param responses Output parameter that receives the predictions.
         *     It is a mutable reference: the method returns void, so a const
         *     reference would leave no way to hand the result to the caller.
         */
        void Predict(const MatType& predictors, MatType& responses) const;

    };

    /**
     * Interface sketch for a HAM unit (presumably "Hierarchical Attentive
     * Memory" - TODO confirm against the referenced arXiv paper).
     *
     * MatType is the last template parameter because a class template
     * parameter with a default argument may not be followed by parameters
     * without one; keeping it last lets callers omit it and get arma::mat.
     *
     * @tparam Memory Memory container: a binary tree (with STL-like iterators
     *     for nodes) - or, in fact, std::vector with helper functions. Must
     *     expose a nested `iterator` type.
     * @tparam Controller Controller model (LSTM in the paper).
     * @tparam EmbedTransformationType Type of the EMBED transformation.
     * @tparam JoinTransformationType Type of the JOIN transformation.
     * @tparam SearchTransformationType Type of the SEARCH transformation.
     * @tparam WriteTransformationType Type of the WRITE transformation.
     * @tparam MatType Matrix type used for data.
     */
    template<typename Memory,
             typename Controller,
             typename EmbedTransformationType,
             typename JoinTransformationType,
             typename SearchTransformationType,
             typename WriteTransformationType,
             typename MatType = arma::mat>
    class HAMUnit {
    public:
        // As an idea - create HighwayNetwork class in ann namespace
        // before creating implementation for HAMUnit
        // (as recommended in arXiv paper for WRITE operation)
        /**
         * @param memorySize Size of the memory.
         * @param controller Controller model.
         * @param embed EMBED transformation.
         * @param join JOIN transformation.
         * @param search SEARCH transformation.
         * @param write WRITE transformation.
         *
         * NOTE(review): only `memory` and `controller` are declared as members;
         * decide how the four transformations are stored before implementing.
         */
        HAMUnit(int memorySize, Controller& controller,
                EmbedTransformationType& embed, JoinTransformationType& join,
                SearchTransformationType& search, WriteTransformationType& write);

        /**
         * Select a memory cell.
         *
         * `typename` is required because Memory::iterator is a dependent type.
         * NOTE(review): this is a const method returning a mutable iterator -
         * confirm that Memory can provide one from a const context.
         *
         * @return Iterator to the selected memory cell.
         */
        typename Memory::iterator Attention() const;

        /**
         * Produce the unit's output for the given memory cell.
         *
         * @param memoryCell Iterator to the memory cell to read.
         */
        MatType Output(typename Memory::iterator memoryCell) const;

        /**
         * Update the memory at the given cell.
         *
         * @param memoryCell Iterator to the memory cell to update.
         */
        void Update(typename Memory::iterator memoryCell);

        // Looks like this one is going to be a REINFORCE-style algorithm.
        /**
         * Train the unit on the given data.
         *
         * @tparam OptimizerType Optimizer class template; it is instantiated
         *     with Controller and deduced from the optimizer argument.
         * @param predictors Input data.
         * @param responses Target outputs for the predictors.
         * @param optimizer Optimizer instance (taken by value).
         * @param gamma Defaults to 0.95 (presumably the reward discount
         *     factor - TODO confirm).
         * @param useHammingDistReward Whether to use a Hamming-distance
         *     reward; off by default.
         * @param entropyAlpha Defaults to 0 (presumably the weight of an
         *     entropy term - TODO confirm).
         */
        template<template<typename> class OptimizerType>
        void Train(const MatType& predictors,
                   const MatType& responses,
                   OptimizerType<Controller> optimizer,
                   double gamma = 0.95,
                   bool useHammingDistReward = false,
                   double entropyAlpha = 0);

        /**
         * Evaluate the unit on the given data.
         *
         * NOTE(review): returns nothing and takes only const inputs - confirm
         * whether a score should be returned.
         *
         * @param predictors Input data.
         * @param responses Expected outputs for the predictors.
         */
        void Evaluate(const MatType& predictors, const MatType& responses) const;
    private:
        Memory memory;
        Controller controller;
    };

    // And now for benchmarking utilities.
    namespace tasks {
        /**
         * Benchmark task: copy / repeat-copy.
         */
        class CopyTask {
        public:
            CopyTask();

            /**
             * Generates an instance of RepeatCopy (a plain copy when
             * nRepeats == 1) and runs the model on it.
             *
             * Expected implementation:
             *   model.Train(*** the instance of CopyTask ***);
             *   return model.Evaluate(*** another instance of the same CopyTask ***);
             *
             * Idea: maybe introduce a dedicated score_t type for the result.
             *
             * @tparam ModelType Type of the model under test.
             * @param model Model to train and evaluate.
             * @param maxLength Defaults to 5 (presumably the maximum length of
             *     a generated sequence - TODO confirm).
             * @param nRepeats Number of repetitions; 1 means a plain copy.
             * @return The score obtained by the model.
             */
            template<class ModelType>
            double Evaluate(ModelType& model,
                            int maxLength = 5,
                            int nRepeats = 1);
        };

        /**
         * Benchmark task named for sequence reversal; only the interface is
         * declared here.
         */
        class ReverseTask {
        public:
            ReverseTask();

            /**
             * Runs the model on this task and reports a score.
             *
             * @tparam ModelType Type of the model under test.
             * @param model Model to run.
             * @param maxLength Defaults to 10 (presumably the maximum length
             *     of a generated sequence - TODO confirm).
             * @return The score obtained by the model.
             */
            template<class ModelType>
            double Evaluate(ModelType& model,
                            int maxLength = 10);
        };

        /**
         * Benchmark task named for sequence sorting; only the interface is
         * declared here.
         */
        class SortTask {
        public:
            SortTask();

            /**
             * Runs the model on this task and reports a score.
             *
             * @tparam ModelType Type of the model under test.
             * @param model Model to run.
             * @param maxLength Defaults to 10 (presumably the maximum length
             *     of a generated sequence - TODO confirm).
             * @return The score obtained by the model.
             */
            template<class ModelType>
            double Evaluate(ModelType& model,
                            int maxLength = 10);
        };

        /**
         * Benchmark task named for addition; only the interface is declared
         * here.
         */
        class AddTask {
        public:
            AddTask();

            /**
             * Runs the model on this task and reports a score.
             *
             * @tparam ModelType Type of the model under test.
             * @param model Model to run.
             * @param maxBits Defaults to 8 (presumably the maximum bit width
             *     of the operands - TODO confirm).
             * @return The score obtained by the model.
             */
            template<class ModelType>
            double Evaluate(ModelType& model,
                            int maxBits = 8);
        };
    }
}
