package flinkdemo.issues.main;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.runtime.operators.CheckpointCommitter;
import org.apache.flink.streaming.runtime.operators.GenericWriteAheadSink;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

public class SlowBulkCopySinkMain {
    public static void main(String[] args) {
        try {
            SlowBulkCopySink engine = new SlowBulkCopySink();
            engine.run();
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(-1);
        }
    }

    public static class SlowBulkCopySink {
        private final static Logger logger = LoggerFactory.getLogger(SlowBulkCopySink.class);

        public void run() throws Exception {
            logger.info("Starting Bulk Copy Test...");

            Configuration conf = new Configuration();
            conf.setBoolean(ConfigConstants.LOCAL_START_WEBSERVER, true);
            final StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironmentWithWebUI(conf);
            env.setParallelism(1);

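            // checkpoint every 30s but keep checkpoints at least 60s apart; the long timeout
            // is presumably there to accommodate the slow sink under test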
            env.enableCheckpointing(30 * 1000);
            env.getCheckpointConfig().setMinPauseBetweenCheckpoints(60 * 1000);
            env.getCheckpointConfig().setCheckpointTimeout(3600 * 1000);
            env.setStateBackend(new HashMapStateBackend());
            env.getCheckpointConfig().setCheckpointStorage("file:/tmp/Flink/State");

            DataStream<DataToPersist> dataToPersistStream = env.addSource(new DataToPersistSource());
            dataToPersistStream.transform("SqlServerSink", TypeInformation.of(DataToPersist.class),
                            new SqlServerBulkCopySink(
                                    new SimpleCheckpointCommitter(),
                                    TypeExtractor.createTypeInfo(DataToPersist.class).createSerializer(new ExecutionConfig()),
                                    UUID.randomUUID().toString()))
                    .setParallelism(4); // bulk copy code can support multiple threads all inserting at the same time

            env.execute("BCP Test");
        }
    }

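    // Unbounded test source: every 500 ms it emits a batch of 10,000 timestamped DataToPersist
    // records, with keys increasing monotonically across batches.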
    private static class DataToPersistSource implements SourceFunction<DataToPersist> {
        private final static Logger logger = LoggerFactory.getLogger(DataToPersistSource.class);
        protected volatile boolean isRunning = true;

        @Override
        public void run(SourceContext<DataToPersist> sourceContext) {
            int key = 0;
            long timestamp = 10000;
            while (sleep()) {
                final int itemsToEmit = 10000;
                // emit a load of data
                for (int i = 0; i < itemsToEmit; i++) {
                    String iStr = Integer.toString(i);
                    sourceContext.collectWithTimestamp(
                            new DataToPersist(
                                    key++, 1000000 + i, "entry" + iStr, 5, 10101, "tradeId" + i, "Bonds",
                                    "Buy", "Bond", "orderId" + i, 5050, 12.3, false, 10000020, 10000050),
                            timestamp++);
                }
                logger.info("Emitted {} items.", itemsToEmit);
            }
        }

        private boolean sleep() {
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                // restore the interrupt flag so cancellation is not swallowed
                Thread.currentThread().interrupt();
            }

            return isRunning;
        }

        @Override
        public void cancel() {
            isRunning = false;
        }
    }

    private static class SqlServerBulkCopySink extends GenericWriteAheadSink<DataToPersist> {
        private final static Logger logger = LoggerFactory.getLogger(SqlServerBulkCopySink.class);

        public SqlServerBulkCopySink(CheckpointCommitter committer, TypeSerializer<DataToPersist> serializer, String jobID) throws Exception {
            super(committer, serializer, jobID);
        }

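        // GenericWriteAheadSink buffers incoming elements in operator state and hands them to
        // sendValues() only once the corresponding checkpoint has completed and the
        // CheckpointCommitter reports it as not yet committed.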
        @Override
        protected boolean sendValues(Iterable<DataToPersist> tradesToPersist, long checkpointId, long timestamp) {
            int subTaskId = this.getRuntimeContext().getIndexOfThisSubtask();
            logger.info("Sending {},{},{}-----------------------------------------------", checkpointId, timestamp, subTaskId);

            // iterate over the Flink Iterable and keep only the latest version of each key for bulk copying into the DB
            // NOTE: two issues here:
            //  - this is EXTREMELY slow! it appears to take a long time to deserialize the objects, but I am not sure why
            //  - the iteration reuses Java objects, so I am forced to clone each object in the loop below before the next iteration - why?
            Map<Integer, DataToPersist> latestVersions = new HashMap<>();
            for (DataToPersist tempItem : tradesToPersist) {
                DataToPersist clonedItem = tempItem.deepClone(); // must fully clone because the iterator reuses the same object instance
                latestVersions.put(clonedItem.key, clonedItem);
            }

            // we would do the actual bulk copying into the SQL Server tables here, but that part is actually pretty quick
            // (an illustrative stand-in sketch is included after this method)
            // ....
            // SQLServerBulkCopy sqlServerBulkCopy = ...
            // .... using latestVersions map from above
            // .... etc

            logger.info("Persisted {} items to the DB.========================================================", latestVersions.size());
            return true;
        }
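
        // Illustrative sketch only (not called from sendValues): roughly how the persistence step could
        // be done with a plain JDBC batch insert as a stand-in for the SQLServerBulkCopy API mentioned
        // above. The JDBC URL, destination table and column names are hypothetical placeholders.
        @SuppressWarnings("unused")
        private void persistWithJdbcBatch(Map<Integer, DataToPersist> latestVersions) throws Exception {
            String url = "jdbc:sqlserver://localhost:1433;databaseName=FlinkDemo"; // hypothetical connection string
            try (java.sql.Connection conn = java.sql.DriverManager.getConnection(url);
                 java.sql.PreparedStatement stmt = conn.prepareStatement(
                         // hypothetical destination table and columns
                         "INSERT INTO dbo.Trades (TradeKey, TradeId, OrderId, Quantity, Price) VALUES (?, ?, ?, ?, ?)")) {
                conn.setAutoCommit(false);
                for (DataToPersist item : latestVersions.values()) {
                    stmt.setInt(1, item.key);
                    stmt.setString(2, item.tradeId);
                    stmt.setString(3, item.orderId);
                    stmt.setLong(4, item.quantity);
                    stmt.setDouble(5, item.price);
                    stmt.addBatch();
                }
                stmt.executeBatch();
                conn.commit();
            }
        }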
    }

    // fake CheckpointCommitter that just keeps track of what has been committed in memory
    public static class SimpleCheckpointCommitter extends CheckpointCommitter {
        private final Set<Tuple2<Integer, Long>> committedItems = new HashSet<>();

        @Override
        public void open() throws Exception {
        }

        @Override
        public void close() throws Exception {
        }

        @Override
        public void createResource() throws Exception {
        }

        @Override
        public void commitCheckpoint(int subtaskIdx, long checkpointID) throws Exception {
            committedItems.add(new Tuple2<>(subtaskIdx, checkpointID));
        }

        @Override
        public boolean isCheckpointCommitted(int subtaskIdx, long checkpointID) throws Exception {
            return committedItems.contains(new Tuple2<>(subtaskIdx, checkpointID));
        }
    }

    public static class DataToPersist {
        public int key;
        public long businessDateMS;
        public String entryId;
        public int fundId;
        public long assetId;
        public String tradeId;
        public String sectorName;
        public String sideStr;
        public String assetTypeName;
        public String orderId;
        public long quantity;
        public double price;
        public boolean isTwilight;
        public long fromDateMS;
        public long toDateMs;

        @SuppressWarnings("unused")
        public DataToPersist() {
            // public no-arg constructor required by Flink's POJO serialization rules
        }

        public DataToPersist(int key, long businessDateMS, String entryId, int fundId, long assetId, String tradeId, String sectorName, String sideStr, String assetTypeName, String orderId, long quantity, double price, boolean isTwilight, long fromDateMS, long toDateMs) {
            this.key = key;
            this.businessDateMS = businessDateMS;
            this.entryId = entryId;
            this.fundId = fundId;
            this.assetId = assetId;
            this.tradeId = Objects.requireNonNull(tradeId);
            this.sectorName = Objects.requireNonNull(sectorName);
            this.sideStr = Objects.requireNonNull(sideStr);
            this.assetTypeName = Objects.requireNonNull(assetTypeName);
            this.orderId = orderId;
            this.quantity = quantity;
            this.price = price;
            this.isTwilight = isTwilight;
            this.fromDateMS = fromDateMS;
            this.toDateMs = toDateMs;
        }

        public DataToPersist deepClone() {
            return new DataToPersist(key, businessDateMS, entryId, fundId, assetId, tradeId, sectorName, sideStr, assetTypeName, orderId, quantity, price, isTwilight, fromDateMS, toDateMs);
        }

        @Override
        public String toString() {
            return "DataToPersist{" +
                    "key=" + key +
                    ", businessDateMS=" + businessDateMS +
                    ", entryId='" + entryId + '\'' +
                    ", fundId=" + fundId +
                    ", assetId=" + assetId +
                    ", tradeId='" + tradeId + '\'' +
                    ", sectorName='" + sectorName + '\'' +
                    ", sideStr='" + sideStr + '\'' +
                    ", assetTypeName='" + assetTypeName + '\'' +
                    ", orderId='" + orderId + '\'' +
                    ", quantity=" + quantity +
                    ", price=" + price +
                    ", isTwilight=" + isTwilight +
                    ", fromDateMS=" + fromDateMS +
                    ", toDateMs=" + toDateMs +
                    '}';
        }
    }
}
