package com.aseg.poiapp.filter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.log4j.Logger;
import org.geotools.filter.function.ClassificationFunction;
import org.geotools.filter.function.RangedClassifier;
import org.opengis.filter.capability.FunctionName;
import org.geotools.filter.capability.FunctionNameImpl;
import org.opengis.filter.expression.Expression;
import org.opengis.filter.identity.FeatureId;

import static org.geotools.filter.capability.FunctionNameImpl.*;

/**
 * Classification function, registered as {@code "Jenks"}, that splits the numeric values held in
 * a {@link Map} into a requested number of ranges using Jenks natural breaks and returns them as a
 * {@link RangedClassifier}.
 */
public class JenksNBFunction extends ClassificationFunction {

    private static final Logger logger = Logger.getLogger("JenksNBFunction");

    /** Function descriptor: a {@code Double} "value" parameter and an {@code Integer} "classes" parameter. */
    public static FunctionName NAME = new FunctionNameImpl("Jenks",
            RangedClassifier.class,
            parameter("value", Double.class),
            parameter("classes", Integer.class));

    /** Creates the function using the {@link #NAME} descriptor. */
    public JenksNBFunction() {
        super(NAME);
    }

    /**
     * Classifies the values of {@code feature} into natural-breaks ranges.
     *
     * @param feature a {@link Map} whose values are numbers; keys are ignored (the original design
     *        keyed them by {@code FeatureId})
     * @return a {@link RangedClassifier}, or {@code null} if {@code feature} is not a map, one of
     *         its non-null values is not a {@link Number}, it holds no finite values, or fewer
     *         than one class is requested
     * @see org.geotools.filter.function.ClassificationFunction#evaluate(java.lang.Object)
     */
    public Object evaluate(Object feature) {
        if (!(feature instanceof Map)) {
            return null;
        }
        return calculate((Map<?, ?>) feature);
    }

    /**
     * Computes the natural-breaks classification of the map's values.
     *
     * This is based on James' GeoTools1 code which seems to be based on
     * http://lib.stat.cmu.edu/cmlib/src/cluster/fish.f
     *
     * @param featureData map whose values are classified; keys are ignored
     * @return a RangedClassifier, or null when no classification can be produced
     */
    private Object calculate(Map<?, ?> featureData) {
        List<Double> data = extractSortedFiniteValues(featureData);
        if (data == null) {
            return null; // a value is not a number: nothing sensible to classify
        }
        final int k = getClasses();
        final int m = data.size();
        if (m == 0) {
            logger.warn("No finite data values to classify; no classification returned");
            return null;
        }
        if (k < 1) {
            logger.warn("Number of classes (" + k + ") must be at least 1; no classification returned");
            return null;
        }
        if (k >= m) {
            // Not enough values to fill k classes: give every value its own class.
            logger.info("Number of classes (" + k + ") is " + (k == m ? "equal to" : "greater than")
                    + " number of data points (" + m + ") " + "unique classification returned");
            return uniqueClassification(data);
        }
        return naturalBreaks(data, k);
    }

    /**
     * Collects the finite numeric values of {@code featureData}, sorted ascending.
     *
     * Null, NaN and infinite values are skipped.
     *
     * @return the sorted finite values, or null if any non-null value is not a {@link Number}
     */
    private static List<Double> extractSortedFiniteValues(Map<?, ?> featureData) {
        List<Double> data = new ArrayList<Double>(featureData.size());
        for (Object value : featureData.values()) {
            if (value == null) {
                continue; // missing value: nothing to classify
            }
            if (!(value instanceof Number)) {
                return null;
            }
            double d = ((Number) value).doubleValue();
            if (!Double.isInfinite(d) && !Double.isNaN(d)) {
                data.add(d);
            }
            if (logger.isTraceEnabled()) {
                logger.trace("importing " + value);
            }
        }
        Collections.sort(data);
        return data;
    }

    /**
     * Builds one class per value of the sorted, non-empty {@code data}.
     *
     * Class i spans data[i]..data[i + 1], so adjacent classes share their boundary value. The last
     * class is the single value data[m - 1].
     */
    private static RangedClassifier uniqueClassification(List<Double> data) {
        final int m = data.size();
        Comparable[] localMin = new Comparable[m];
        Comparable[] localMax = new Comparable[m];
        for (int id = 0; id < m - 1; id++) {
            localMin[id] = data.get(id);
            localMax[id] = data.get(id + 1);
        }
        localMin[m - 1] = data.get(m - 1);
        localMax[m - 1] = data.get(m - 1);
        return new RangedClassifier(localMin, localMax);
    }

    /**
     * Runs the Fisher/Jenks dynamic program over sorted {@code data} and extracts {@code k} classes.
     *
     * Rows i (1..m) count how many leading sorted values are considered; columns j (1..k) count
     * classes. work[i][j] is the smallest total within-class sum of squared deviations found for
     * the first i values in j classes. iwork[i][j] is the 0-based index into {@code data} of the
     * first value of the last of those j classes.
     *
     * Requires 1 <= k < data.size().
     */
    private static RangedClassifier naturalBreaks(List<Double> data, int k) {
        final int m = data.size();
        int[][] iwork = new int[m + 1][k + 1];
        double[][] work = new double[m + 1][k + 1];

        for (int j = 1; j <= k; j++) {
            // the first item is always in the first class!
            iwork[0][j] = 1;
            iwork[1][j] = 1;
            // initialize work matrix: a single value costs nothing, everything else is unknown
            work[1][j] = 0;
            for (int i = 2; i <= m; i++) {
                work[i][j] = Double.MAX_VALUE;
            }
        }

        for (int i = 1; i <= m; i++) {
            // running sum and sum of squares of the candidate last class
            double s1 = 0;
            double s2 = 0;
            double var = 0.0;
            // grow the candidate last class backwards from value i towards value 1
            for (int ii = 1; ii <= i; ii++) {
                // 1-based index of the first value of the candidate last class
                int i3 = i - ii + 1;
                double val = data.get(i3 - 1);
                s2 += val * val;
                s1 += val;
                double s0 = ii;
                // sum of squared deviations of the candidate class (rapid calculation method)
                var = s2 - ((s1 * s1) / s0);
                // number of values left for the preceding classes
                int ik = i3 - 1;
                if (ik != 0) {
                    for (int j = 2; j <= k; j++) {
                        // keep the cheaper split; ">=" lets later (longer) candidates win ties
                        if (work[i][j] >= (var + work[ik][j - 1])) {
                            iwork[i][j] = ik; // 0-based index of the candidate class's first value
                            work[i][j] = var + work[ik][j - 1];
                        }
                    }
                }
            }
            // a single class spans all i values
            iwork[i][1] = 1;
            work[i][1] = var;
        }

        if (logger.isTraceEnabled()) {
            for (int i = 1; i <= m; i++) {
                StringBuilder row = new StringBuilder();
                row.append(i - 1).append(": ").append(data.get(i - 1));
                for (int j = 2; j <= k; j++) {
                    row.append('\t').append(iwork[i][j]);
                }
                logger.trace(row.toString());
            }
        }

        // Walk back from the full data set. The last class of the first `count` values starts at
        // data index iwork[count][j], and the values before it form the first j - 1 classes.
        Comparable[] localMin = new Comparable[k];
        Comparable[] localMax = new Comparable[k];
        localMin[0] = data.get(0);
        localMax[k - 1] = data.get(m - 1);
        int count = m;
        for (int j = k; j >= 2; j--) {
            int first = iwork[count][j];
            if (logger.isTraceEnabled()) {
                logger.trace("index " + count + ", class" + j);
            }
            // adjacent classes share the boundary value, as in uniqueClassification
            localMin[j - 1] = data.get(first);
            localMax[j - 2] = data.get(first);
            count = first;
        }
        return new RangedClassifier(localMin, localMax);
    }
}
