package com.nokia.search.address;

import java.io.IOException;
import java.io.Reader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Map;
import java.util.Random;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.icu.segmentation.ICUTokenizer;
import org.apache.lucene.document.BinaryDocValuesField;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.AtomicReaderContext;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.RAMDirectory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Version;

public class LuceneFloatSourceTest {
    
    /**
     * Value source that scores each document by the Euclidean length {@code sqrt(x*x + y*y)} of
     * the vector formed by two numeric doc-values fields, or by the squared length
     * {@code x*x + y*y} when the square root is skipped.
     */
    private static class LuceneNumericFieldBasedGeometricDistance extends ValueSource {
        private final boolean withoutSquareRoot;
        private final String xField;
        private final String yField;

        /**
         * @param xField name of the numeric doc-values field holding the x component
         * @param yField name of the numeric doc-values field holding the y component
         * @param withoutSquareRoot if {@code true}, score with the squared length (no square root)
         * @throws NullPointerException if either field name is {@code null}
         */
        public LuceneNumericFieldBasedGeometricDistance(final String xField, final String yField, boolean withoutSquareRoot) {
            if (xField == null) {
                throw new NullPointerException("xField");
            }
            if (yField == null) {
                throw new NullPointerException("yField");
            }
            this.xField = xField;
            this.yField = yField;
            this.withoutSquareRoot = withoutSquareRoot;
        }

        /**
         * Returns per-segment function values reading both numeric fields.
         *
         * @throws IllegalStateException if the segment has no numeric doc values for either field
         */
        @Override
        public FunctionValues getValues(Map context,
                AtomicReaderContext readerContext) throws IOException {
            final NumericDocValues xValues = requireValues(readerContext, this.xField);
            final NumericDocValues yValues = requireValues(readerContext, this.yField);
            // Two specialised implementations keep the per-document path free of a branch on
            // withoutSquareRoot.
            if (withoutSquareRoot) {
                return new FunctionValues() {

                    @Override
                    public float floatVal(int doc) {
                        return (float) squaredLength(xValues.get(doc), yValues.get(doc));
                    }

                    @Override
                    public String toString(int doc) {
                        return "geometric length function value source";
                    }
                };
            }
            return new FunctionValues() {

                @Override
                public float floatVal(int doc) {
                    return (float) Math.sqrt(squaredLength(xValues.get(doc), yValues.get(doc)));
                }

                @Override
                public String toString(int doc) {
                    return "geometric length function value source";
                }
            };
        }

        /**
         * Returns {@code x*x + y*y}, computed in double arithmetic because the product of two
         * {@code long}s wraps around silently on overflow.
         */
        private static double squaredLength(long x, long y) {
            final double dx = x;
            final double dy = y;
            return dx * dx + dy * dy;
        }

        /** Fetches the numeric doc values of {@code field}, failing with context if they are absent. */
        private static NumericDocValues requireValues(AtomicReaderContext readerContext, String field) throws IOException {
            final NumericDocValues values = readerContext.reader().getNumericDocValues(field);
            if (values == null) {
                throw new IllegalStateException("No numeric doc values for field '" + field
                        + "' in segment with docBase " + readerContext.docBase);
            }
            return values;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof LuceneNumericFieldBasedGeometricDistance)) {
                return false;
            }
            final LuceneNumericFieldBasedGeometricDistance other = (LuceneNumericFieldBasedGeometricDistance) o;
            return withoutSquareRoot == other.withoutSquareRoot
                    && xField.equals(other.xField)
                    && yField.equals(other.yField);
        }

        @Override
        public int hashCode() {
            int h = xField.hashCode();
            h = 31 * h + yField.hashCode();
            h = 31 * h + (withoutSquareRoot ? 1231 : 1237);
            return h;
        }

        @Override
        public String description() {
            return (withoutSquareRoot ? "squaredGeometricLength(" : "geometricLength(")
                    + xField + "," + yField + ")";
        }
    }
    
    /**
     * Value source that scores each document by the Euclidean length of an (x, y) pair packed into
     * one binary doc-values field as two 4-byte big-endian IEEE-754 floats (x first, then y). When
     * the square root is skipped, the score is the squared length.
     * <p>
     * Values shorter than 8 bytes score 0. This presumably covers documents that have no value for
     * the field; confirm against the codec in use.
     */
    private static class LuceneBinaryFieldGeometricDistance extends ValueSource {

        /** Number of bytes in one packed (x, y) pair. */
        private static final int PAIR_LENGTH = 8;

        private final String field;
        private final boolean withoutSquareRoot;

        /**
         * @param field name of the binary doc-values field holding the packed (x, y) pair
         * @param withoutSquareRoot if {@code true}, score with the squared length (no square root)
         * @throws NullPointerException if {@code field} is {@code null}
         */
        public LuceneBinaryFieldGeometricDistance(String field, boolean withoutSquareRoot) {
            if (field == null) {
                throw new NullPointerException("field");
            }
            this.field = field;
            this.withoutSquareRoot = withoutSquareRoot;
        }

        /**
         * Returns per-segment function values decoding the packed pair of each document.
         *
         * @throws IllegalStateException if the segment has no binary doc values for the field
         */
        @Override
        public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException {
            final BinaryDocValues binaryDocValues = readerContext.reader().getBinaryDocValues(this.field);
            if (binaryDocValues == null) {
                throw new IllegalStateException("No binary doc values for field '" + field
                        + "' in segment with docBase " + readerContext.docBase);
            }

            if (withoutSquareRoot) {
                return new FunctionValues() {

                    // Reused across calls; filled in by binaryDocValues.get for each document.
                    private final BytesRef br = new BytesRef(new byte[PAIR_LENGTH]);

                    @Override
                    public float floatVal(int doc) {
                        binaryDocValues.get(doc, br);
                        if (br.length < PAIR_LENGTH) {
                            // Not a full pair: reading 8 bytes would run past the value.
                            return 0f;
                        }
                        final float x = readFloat(br.bytes, br.offset);
                        final float y = readFloat(br.bytes, br.offset + 4);
                        return y * y + x * x;
                    }

                    @Override
                    public String toString(int doc) {
                        return "geometric length function value source";
                    }
                };
            }
            return new FunctionValues() {

                // Reused across calls; filled in by binaryDocValues.get for each document.
                private final BytesRef br = new BytesRef(new byte[PAIR_LENGTH]);

                @Override
                public float floatVal(int doc) {
                    binaryDocValues.get(doc, br);
                    if (br.length < PAIR_LENGTH) {
                        // Not a full pair: reading 8 bytes would run past the value.
                        return 0f;
                    }
                    final float x = readFloat(br.bytes, br.offset);
                    final float y = readFloat(br.bytes, br.offset + 4);
                    return (float) Math.sqrt(y * y + x * x);
                }

                @Override
                public String toString(int doc) {
                    return "geometric length function value source";
                }
            };
        }

        /** Decodes a big-endian IEEE-754 float from {@code bytes[offset..offset+3]}. */
        static float readFloat(byte[] bytes, int offset) {
            return Float.intBitsToFloat(readInt(bytes, offset));
        }

        /** Decodes a big-endian 32-bit int from {@code bytes[offset..offset+3]}. */
        static int readInt(byte[] bytes, int offset) {
            return ((bytes[offset] & 0xff) << 24)
                    | ((bytes[offset + 1] & 0xff) << 16)
                    | ((bytes[offset + 2] & 0xff) << 8)
                    | (bytes[offset + 3] & 0xff);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof LuceneBinaryFieldGeometricDistance)) {
                return false;
            }
            final LuceneBinaryFieldGeometricDistance other = (LuceneBinaryFieldGeometricDistance) o;
            return withoutSquareRoot == other.withoutSquareRoot && field.equals(other.field);
        }

        @Override
        public int hashCode() {
            return 31 * field.hashCode() + (withoutSquareRoot ? 1231 : 1237);
        }

        @Override
        public String description() {
            return (withoutSquareRoot ? "squaredGeometricLength(" : "geometricLength(") + field + ")";
        }

    }
    
    /**
     * Value source that scores documents from two caller-supplied float buffers holding the x and
     * y components. The buffers are indexed by top-level (index-wide) document id. The score is
     * {@code sqrt(x*x + y*y)}, or {@code x*x + y*y} when the square root is skipped.
     * <p>
     * Function values are requested per segment with segment-relative doc ids, so every lookup is
     * offset by the segment's {@code docBase}. This assumes the buffers were filled in indexing
     * order and that doc ids have not been renumbered since, e.g. by merges that drop deleted
     * documents. TODO confirm before using with indexes other than the one built in {@code main}.
     * @author Christian_2
     *
     */
    private static class DirectBufferBasedGeometricDistance extends ValueSource {

        final FloatBuffer xbuf;
        final FloatBuffer ybuf;
        private final boolean withoutSquareRoot;

        /**
         * @param x buffer of x components, indexed by top-level doc id
         * @param y buffer of y components, indexed by top-level doc id
         * @param withoutSquareRoot if {@code true}, score with the squared length (no square root)
         * @throws NullPointerException if either buffer is {@code null}
         */
        public DirectBufferBasedGeometricDistance(final FloatBuffer x, final FloatBuffer y, boolean withoutSquareRoot) {
            if (x == null) {
                throw new NullPointerException("x");
            }
            if (y == null) {
                throw new NullPointerException("y");
            }
            this.xbuf = x;
            this.ybuf = y;
            this.withoutSquareRoot = withoutSquareRoot;
        }

        @Override
        public FunctionValues getValues(Map context,
                AtomicReaderContext readerContext) throws IOException {
            // floatVal receives segment-relative doc ids; translate them to the top-level ids the
            // buffers are indexed by.
            final int docBase = readerContext.docBase;
            if (withoutSquareRoot) {
                return new FunctionValues() {

                    @Override
                    public float floatVal(int doc) {
                        final float y = ybuf.get(docBase + doc);
                        final float x = xbuf.get(docBase + doc);

                        return y * y + x * x;
                    }

                    @Override
                    public String toString(int doc) {
                        return "geometric length function value source";
                    }
                };
            }
            return new FunctionValues() {

                @Override
                public float floatVal(int doc) {
                    final float y = ybuf.get(docBase + doc);
                    final float x = xbuf.get(docBase + doc);

                    return (float) Math.sqrt(y * y + x * x);
                }

                @Override
                public String toString(int doc) {
                    return "geometric length function value source";
                }
            };
        }

        /**
         * Compares buffers by identity. FloatBuffer.equals compares only the remaining elements,
         * which depends on each buffer's position and scans the whole buffer.
         */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof DirectBufferBasedGeometricDistance)) {
                return false;
            }
            final DirectBufferBasedGeometricDistance other = (DirectBufferBasedGeometricDistance) o;
            return withoutSquareRoot == other.withoutSquareRoot
                    && xbuf == other.xbuf
                    && ybuf == other.ybuf;
        }

        @Override
        public int hashCode() {
            int h = System.identityHashCode(xbuf);
            h = 31 * h + System.identityHashCode(ybuf);
            h = 31 * h + (withoutSquareRoot ? 1231 : 1237);
            return h;
        }

        @Override
        public String description() {
            return withoutSquareRoot ? "squaredDirectBufferGeometricLength" : "directBufferGeometricLength";
        }
    }
    
    /**
     * Stored-field visitor that adds up the values of every float stored field it visits.
     * <p>
     * Its {@code floatField} override does not delegate to {@code super.floatField}, so the
     * superclass's own handling of float fields is skipped for this visitor.
     */
    public static class Counter extends DocumentStoredFieldVisitor {

        /** Sum of all float stored-field values visited so far; starts at zero. */
        public float res = 0.0f;

        @Override
        public void floatField(FieldInfo fieldInfo, float value) {
            this.res = this.res + value;
        }
    }

    /** Number of timed repetitions per scoring strategy. */
    private static final int BENCHMARK_ROUNDS = 5;

    /** Capacity in bytes of each direct buffer; the document count is a quarter of this. */
    private static final int BUFFER_BYTES = 100000000;

    /** Scale applied to the random floats in [0, 1) to produce the long values of the numeric fields. */
    private static final float INT_FACTOR = 100000000f;

    /**
     * Benchmarks scoring every document by the length of a random 2-D vector.
     * <p>
     * The vector is read from three sources: direct float buffers, two numeric doc-values fields,
     * and one packed binary doc-values field. Each source is timed with and without a square root.
     */
    public static void main(String[] args) throws IOException {
        final ByteBuffer dbb = ByteBuffer.allocateDirect(BUFFER_BYTES);
        final ByteBuffer dbb2 = ByteBuffer.allocateDirect(BUFFER_BYTES);
        // Both buffers must share the same (native) byte order, otherwise one float view may pay for
        // byte swapping on every access and the comparison is skewed.
        dbb.order(ByteOrder.nativeOrder());
        dbb2.order(ByteOrder.nativeOrder());
        final FloatBuffer dfb = dbb.asFloatBuffer();
        final FloatBuffer dfb2 = dbb2.asFloatBuffer();
        final int size = dfb.limit();
        fillRandom(dfb, dfb2, size, new Random());

        // An FSDirectory can be substituted here to benchmark on-disk doc values.
        final RAMDirectory dir = new RAMDirectory();
        try {
            indexDocuments(dir, dfb, dfb2, size);

            System.out.println("now reading numbers");
            final DirectoryReader dr = DirectoryReader.open(dir);
            try {
                final IndexSearcher searcher = new IndexSearcher(dr);

                benchmark(searcher, new DirectBufferBasedGeometricDistance(dfb, dfb2, true),
                        "direct float buffers (without square root)", 1f);
                benchmark(searcher, new DirectBufferBasedGeometricDistance(dfb, dfb2, false),
                        "direct float buffers (and a square root)", 1f);
                // Numeric scores are in scaled units; divide to print them on the float scale.
                benchmark(searcher, new LuceneNumericFieldBasedGeometricDistance("xInt", "yInt", true),
                        "a lucene numeric values (without square root) source", INT_FACTOR * INT_FACTOR);
                benchmark(searcher, new LuceneNumericFieldBasedGeometricDistance("xInt", "yInt", false),
                        "a lucene numeric values (and a square root) source", INT_FACTOR);
                benchmark(searcher, new LuceneBinaryFieldGeometricDistance("xy", true),
                        "a lucene binary values (without square root) source", 1f);
                benchmark(searcher, new LuceneBinaryFieldGeometricDistance("xy", false),
                        "a lucene binary values (and square root) source", 1f);
            } finally {
                dr.close();
            }
        } finally {
            dir.close();
        }
    }

    /**
     * Fills the first {@code size} slots of both buffers with uniform random floats in [0, 1),
     * drawing the x value before the y value for each index.
     */
    private static void fillRandom(FloatBuffer xs, FloatBuffer ys, int size, Random rnd) {
        for (int i = 0; i < size; i++) {
            xs.put(i, rnd.nextFloat());
            ys.put(i, rnd.nextFloat());
        }
    }

    /**
     * Indexes one document per buffer position and merges the index down to a single segment.
     * <p>
     * Each document carries its (x, y) pair twice: packed into the binary doc value "xy" as two
     * big-endian floats, and as the numeric doc values "xInt" and "yInt", scaled by
     * {@link #INT_FACTOR}.
     */
    private static void indexDocuments(RAMDirectory dir, FloatBuffer xs, FloatBuffer ys, int size) throws IOException {
        final IndexWriterConfig conf = new IndexWriterConfig(Version.LUCENE_42, new Analyzer() {

            @Override
            protected TokenStreamComponents createComponents(String fieldName,
                    Reader reader) {
                final Tokenizer tokenizer = new ICUTokenizer(reader);
                return new TokenStreamComponents(tokenizer);
            }

        });
        final IndexWriter iw = new IndexWriter(dir, conf);
        try {
            final BytesRef bytesRef = new BytesRef(new byte[8]);
            // A wrapped ByteBuffer is big-endian by default, matching the byte order that
            // LuceneBinaryFieldGeometricDistance.readInt decodes.
            final ByteBuffer bb = ByteBuffer.wrap(bytesRef.bytes);
            for (int i = 0; i < size; i++) {
                final float x = xs.get(i);
                final float y = ys.get(i);
                bb.putFloat(0, x);
                bb.putFloat(4, y);

                final Document doc = new Document();
                doc.add(new BinaryDocValuesField("xy", bytesRef));
                doc.add(new NumericDocValuesField("xInt", (long) (x * INT_FACTOR)));
                doc.add(new NumericDocValuesField("yInt", (long) (y * INT_FACTOR)));
                iw.addDocument(doc);
                if (i % 1000000 == 0) {
                    System.out.println("Indexed " + i);
                }
            }
            iw.commit();
            iw.forceMerge(1);
        } finally {
            iw.close();
        }
    }

    /**
     * Runs {@link #BENCHMARK_ROUNDS} timed top-1 searches scored by {@code source}.
     * <p>
     * Each round prints the hit count, the elapsed milliseconds, and the top score divided by
     * {@code scoreDivisor}.
     */
    private static void benchmark(IndexSearcher searcher, ValueSource source, String description,
            float scoreDivisor) throws IOException {
        for (int i = 0; i < BENCHMARK_ROUNDS; i++) {
            // nanoTime is monotonic, unlike currentTimeMillis, so it is suited to measuring intervals.
            final long start = System.nanoTime();
            final TopDocs topDocs = searcher.search(new FunctionQuery(source), 1);
            final long elapsedMillis = (System.nanoTime() - start) / 1000000L;
            System.out.println("Scoring " + topDocs.totalHits + " documents with " + description + " took " + elapsedMillis);
            System.out.println("Top doc: " + topDocs.scoreDocs[0].score / scoreDivisor);
        }
    }
    
}