Repository: incubator-pirk Updated Branches: refs/heads/master bc740a79b -> 8a1c4e153
Convert DecryptResponseRunnable to a Callable, this closes apache/incubator-pirk#68 Project: http://git-wip-us.apache.org/repos/asf/incubator-pirk/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-pirk/commit/8a1c4e15 Tree: http://git-wip-us.apache.org/repos/asf/incubator-pirk/tree/8a1c4e15 Diff: http://git-wip-us.apache.org/repos/asf/incubator-pirk/diff/8a1c4e15 Branch: refs/heads/master Commit: 8a1c4e153b3965e9f6d1c9e55428bb2cda4a9c9c Parents: bc740a7 Author: Tim Ellison <t.p.elli...@gmail.com> Authored: Thu Aug 18 10:51:35 2016 -0400 Committer: smarthi <smar...@apache.org> Committed: Thu Aug 18 10:51:35 2016 -0400 ---------------------------------------------------------------------- .../wideskies/decrypt/DecryptResponse.java | 47 ++--- .../decrypt/DecryptResponseRunnable.java | 172 ------------------- .../wideskies/decrypt/DecryptResponseTask.java | 152 ++++++++++++++++ 3 files changed, 176 insertions(+), 195 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-pirk/blob/8a1c4e15/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java index 2231160..97b93fd 100644 --- a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java +++ b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponse.java @@ -29,9 +29,12 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.TreeMap; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import 
org.apache.pirk.encryption.Paillier; import org.apache.pirk.querier.wideskies.Querier; @@ -48,6 +51,8 @@ import org.slf4j.LoggerFactory; public class DecryptResponse { private static final Logger logger = LoggerFactory.getLogger(DecryptResponse.class); + + private static final BigInteger TWO_BI = BigInteger.valueOf(2); private final Response response; @@ -87,25 +92,24 @@ public class DecryptResponse Paillier paillier = querier.getPaillier(); List<String> selectors = querier.getSelectors(); - HashMap<Integer,String> embedSelectorMap = querier.getEmbedSelectorMap(); + Map<Integer,String> embedSelectorMap = querier.getEmbedSelectorMap(); // Perform decryption on the encrypted columns - ArrayList<BigInteger> rElements = decryptElements(response.getResponseElements(), paillier); + List<BigInteger> rElements = decryptElements(response.getResponseElements(), paillier); logger.debug("rElements.size() = " + rElements.size()); // Pull the necessary parameters int dataPartitionBitSize = queryInfo.getDataPartitionBitSize(); // Initialize the result map and masks-- removes initialization checks from code below - HashMap<String,BigInteger> selectorMaskMap = new HashMap<>(); + Map<String,BigInteger> selectorMaskMap = new HashMap<>(); int selectorNum = 0; - BigInteger twoBI = BigInteger.valueOf(2); for (String selector : selectors) { resultMap.put(selector, new ArrayList<>()); // 2^{selectorNum*dataPartitionBitSize}(2^{dataPartitionBitSize} - 1) - BigInteger mask = twoBI.pow(selectorNum * dataPartitionBitSize).multiply((twoBI.pow(dataPartitionBitSize).subtract(BigInteger.ONE))); + BigInteger mask = TWO_BI.pow(selectorNum * dataPartitionBitSize).multiply((TWO_BI.pow(dataPartitionBitSize).subtract(BigInteger.ONE))); logger.debug("selector = " + selector + " mask = " + mask.toString(2)); selectorMaskMap.put(selector, mask); @@ -120,7 +124,7 @@ public class DecryptResponse } int elementsPerThread = selectors.size() / numThreads; // Integral division. 
- ArrayList<DecryptResponseRunnable> runnables = new ArrayList<>(); + List<Future<Map<String,List<QueryResponseJSON>>>> futures = new ArrayList<>(); for (int i = 0; i < numThreads; ++i) { // Grab the range of the thread and create the corresponding partition of selectors @@ -137,33 +141,30 @@ public class DecryptResponse } // Create the runnable and execute - // selectorMaskMap and rElements are synchronized, pirWatchlist is copied, selectors is partitioned - DecryptResponseRunnable runDec = new DecryptResponseRunnable(rElements, selectorsPartition, selectorMaskMap, queryInfo.clone(), embedSelectorMap); - runnables.add(runDec); - es.execute(runDec); - } - - // Allow threads to complete - es.shutdown(); // previously submitted tasks are executed, but no new tasks will be accepted - boolean finished = es.awaitTermination(1, TimeUnit.DAYS); // waits until all tasks complete or until the specified timeout - - if (!finished) - { - throw new PIRException("Decryption threads did not finish in the alloted time"); + DecryptResponseRunnable<Map<String,List<QueryResponseJSON>>> runDec = new DecryptResponseRunnable<>(rElements, selectorsPartition, selectorMaskMap, queryInfo.clone(), embedSelectorMap); + futures.add(es.submit(runDec)); } // Pull all decrypted elements and add to resultMap - for (DecryptResponseRunnable runner : runnables) + try + { + for (Future<Map<String,List<QueryResponseJSON>>> future : futures) + { + resultMap.putAll(future.get(1, TimeUnit.DAYS)); + } + } catch (TimeoutException | ExecutionException e) { - resultMap.putAll(runner.getResultMap()); + throw new PIRException("Exception in decryption threads.", e); } + + es.shutdown(); } // Method to perform basic decryption of each raw response element - does not // extract and reconstruct the data elements - private ArrayList<BigInteger> decryptElements(TreeMap<Integer,BigInteger> elements, Paillier paillier) + private List<BigInteger> decryptElements(TreeMap<Integer,BigInteger> elements, Paillier paillier) { 
- ArrayList<BigInteger> decryptedElements = new ArrayList<>(); + List<BigInteger> decryptedElements = new ArrayList<>(); for (BigInteger encElement : elements.values()) { http://git-wip-us.apache.org/repos/asf/incubator-pirk/blob/8a1c4e15/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java deleted file mode 100644 index 531ec6a..0000000 --- a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseRunnable.java +++ /dev/null @@ -1,172 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.pirk.querier.wideskies.decrypt; - -import java.math.BigInteger; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -import org.apache.pirk.query.wideskies.QueryInfo; -import org.apache.pirk.query.wideskies.QueryUtils; -import org.apache.pirk.schema.query.QuerySchema; -import org.apache.pirk.schema.query.QuerySchemaRegistry; -import org.apache.pirk.schema.response.QueryResponseJSON; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Runnable class for multithreaded PIR decryption - * <p> - * NOTE: rElements and selectorMaskMap are joint access objects, for now - * - */ -public class DecryptResponseRunnable implements Runnable -{ - private static final Logger logger = LoggerFactory.getLogger(DecryptResponseRunnable.class); - - private final Map<String,List<QueryResponseJSON>> resultMap = new HashMap<>(); // selector -> ArrayList of hits - private final List<BigInteger> rElements; - private final TreeMap<Integer,String> selectors; - private final Map<String,BigInteger> selectorMaskMap; - private final QueryInfo queryInfo; - - private final Map<Integer,String> embedSelectorMap; - - public DecryptResponseRunnable(List<BigInteger> rElementsInput, TreeMap<Integer,String> selectorsInput, Map<String,BigInteger> selectorMaskMapInput, - QueryInfo queryInfoInput, Map<Integer,String> embedSelectorMapInput) - { - rElements = rElementsInput; - selectors = selectorsInput; - selectorMaskMap = selectorMaskMapInput; - queryInfo = queryInfoInput; - embedSelectorMap = embedSelectorMapInput; - } - - public Map<String,List<QueryResponseJSON>> getResultMap() - { - return resultMap; - } - - @Override - public void run() - { - // Pull the necessary parameters - int dataPartitionBitSize = queryInfo.getDataPartitionBitSize(); - int numPartitionsPerDataElement = queryInfo.getNumPartitionsPerDataElement(); - - QuerySchema qSchema = 
QuerySchemaRegistry.get(queryInfo.getQueryType()); - String selectorName = qSchema.getSelectorName(); - - // Initialize - removes checks below - for (String selector : selectors.values()) - { - resultMap.put(selector, new ArrayList<QueryResponseJSON>()); - } - - logger.debug("numResults = " + rElements.size() + " numPartitionsPerDataElement = " + numPartitionsPerDataElement); - - // Pull the hits for each selector - int hits = 0; - int maxHitsPerSelector = rElements.size() / numPartitionsPerDataElement; // Max number of data hits in the response elements for a given selector - logger.debug("numHits = " + maxHitsPerSelector); - while (hits < maxHitsPerSelector) - { - int selectorIndex = selectors.firstKey(); - while (selectorIndex <= selectors.lastKey()) - { - String selector = selectors.get(selectorIndex); - logger.debug("selector = " + selector); - - ArrayList<BigInteger> parts = new ArrayList<>(); - int partNum = 0; - boolean zeroElement = true; - while (partNum < numPartitionsPerDataElement) - { - BigInteger part = (rElements.get(hits * numPartitionsPerDataElement + partNum)).and(selectorMaskMap.get(selector)); // pull off the correct bits - - logger.debug("rElements.get(" + (hits * numPartitionsPerDataElement + partNum) + ") = " - + rElements.get(hits * numPartitionsPerDataElement + partNum).toString(2) + " bitLength = " - + rElements.get(hits * numPartitionsPerDataElement + partNum).bitLength() + " val = " - + rElements.get(hits * numPartitionsPerDataElement + partNum)); - logger.debug("colNum = " + (hits * numPartitionsPerDataElement + partNum) + " partNum = " + partNum + " part = " + part); - - part = part.shiftRight(selectorIndex * dataPartitionBitSize); - parts.add(part); - - logger.debug("partNum = " + partNum + " part = " + part.intValue()); - - if (zeroElement) - { - if (!part.equals(BigInteger.ZERO)) - { - zeroElement = false; - } - } - ++partNum; - } - - logger.debug("parts.size() = " + parts.size()); - - if (!zeroElement) - { - // Convert biHit to 
the appropriate QueryResponseJSON object, based on the queryType - QueryResponseJSON qrJOSN = null; - try - { - qrJOSN = QueryUtils.extractQueryResponseJSON(queryInfo, qSchema, parts); - } catch (Exception e) - { - e.printStackTrace(); - throw new RuntimeException(e); - } - qrJOSN.setMapping(selectorName, selector); - logger.debug("selector = " + selector + " qrJOSN = " + qrJOSN.getJSONString()); - - // Add the hit for this selector - if we are using embedded selectors, check to make sure - // that the hit's embedded selector in the qrJOSN and the once in the embedSelectorMap match - boolean addHit = true; - if (queryInfo.getEmbedSelector()) - { - if (!(embedSelectorMap.get(selectorIndex)).equals(qrJOSN.getValue(QueryResponseJSON.SELECTOR))) - { - addHit = false; - logger.debug("qrJOSN embedded selector = " + qrJOSN.getValue(QueryResponseJSON.SELECTOR) + " != original embedded selector = " - + embedSelectorMap.get(selectorIndex)); - } - } - if (addHit) - { - List<QueryResponseJSON> selectorHitList = resultMap.get(selector); - selectorHitList.add(qrJOSN); - resultMap.put(selector, selectorHitList); - - // Add the selector into the wlJSONHit - qrJOSN.setMapping(QueryResponseJSON.SELECTOR, selector); - } - } - - ++selectorIndex; - } - ++hits; - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-pirk/blob/8a1c4e15/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java new file mode 100644 index 0000000..7b197d8 --- /dev/null +++ b/src/main/java/org/apache/pirk/querier/wideskies/decrypt/DecryptResponseTask.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pirk.querier.wideskies.decrypt; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.Callable; + +import org.apache.pirk.query.wideskies.QueryInfo; +import org.apache.pirk.query.wideskies.QueryUtils; +import org.apache.pirk.schema.query.QuerySchema; +import org.apache.pirk.schema.query.QuerySchemaRegistry; +import org.apache.pirk.schema.response.QueryResponseJSON; +import org.apache.pirk.utils.PIRException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Callable task for multithreaded PIR decryption: extracts the hits for one partition of selectors from the already-decrypted response elements and returns them as a map from selector to its list of hits. + * <p> + * NOTE: all constructor arguments are stored by reference, not copied; rElements and selectorMaskMap are shared by the caller across all tasks and are only read here. + * NOTE(review): this class is declared as DecryptResponseRunnable in DecryptResponseTask.java and its type parameter V is unused; consider renaming it to DecryptResponseTask and dropping V together with its caller in DecryptResponse. 
+ */ +class DecryptResponseRunnable<V> implements Callable<Map<String,List<QueryResponseJSON>>> +{ + private static final Logger logger = LoggerFactory.getLogger(DecryptResponseRunnable.class); + + private final List<BigInteger> rElements; + private final TreeMap<Integer,String> selectors; + private final Map<String,BigInteger> selectorMaskMap; + private final QueryInfo queryInfo; + + private final Map<Integer,String> embedSelectorMap; + + public DecryptResponseRunnable(List<BigInteger> rElementsInput,
TreeMap<Integer,String> selectorsInput, Map<String,BigInteger> selectorMaskMapInput, + QueryInfo queryInfoInput, Map<Integer,String> embedSelectorMapInput) + { + rElements = rElementsInput; + selectors = selectorsInput; + selectorMaskMap = selectorMaskMapInput; + queryInfo = queryInfoInput; + embedSelectorMap = embedSelectorMapInput; + } + /** Extracts the hits for every selector in this partition; each selector maps to a (possibly empty) list. A PIRException presumably propagates from QueryUtils.extractQueryResponseJSON (TODO confirm). */ + @Override + public Map<String,List<QueryResponseJSON>> call() throws PIRException + { + // Pull the necessary parameters + int dataPartitionBitSize = queryInfo.getDataPartitionBitSize(); + int numPartitionsPerDataElement = queryInfo.getNumPartitionsPerDataElement(); + + QuerySchema qSchema = QuerySchemaRegistry.get(queryInfo.getQueryType()); + String selectorName = qSchema.getSelectorName(); + + // Result is a map of (selector -> List of hits), pre-populated with an empty list for every selector in this partition. + Map<String,List<QueryResponseJSON>> resultMap = new HashMap<>(); + for (String selector : selectors.values()) + { + resultMap.put(selector, new ArrayList<QueryResponseJSON>()); + } + + // Pull the hits for each selector; hit number h occupies the numPartitionsPerDataElement consecutive elements starting at h * numPartitionsPerDataElement + int maxHitsPerSelector = rElements.size() / numPartitionsPerDataElement; // Max number of data hits per selector (integer division ignores any trailing partial group of elements) + logger.debug("numResults = " + rElements.size() + " numPartitionsPerDataElement = " + numPartitionsPerDataElement + " maxHits = " + maxHitsPerSelector); + + for (int hits = 0; hits < maxHitsPerSelector; hits++) + { + int selectorIndex = selectors.firstKey(); // steps by 1 from firstKey to lastKey, so this assumes the partition's keys are contiguous + while (selectorIndex <= selectors.lastKey()) + { + String selector = 
selectors.get(selectorIndex); + logger.debug("selector = " + selector); + + List<BigInteger> parts = new ArrayList<>(); + boolean zeroElement = true; + for (int partNum = 0; partNum < numPartitionsPerDataElement; partNum++) + { + BigInteger part = (rElements.get(hits * numPartitionsPerDataElement + partNum)).and(selectorMaskMap.get(selector)); // keep only this selector's bits (its mask from selectorMaskMap) + // NOTE(review): the debug arguments below (incl. toString(2) of each element) are built even when debug logging is off; consider a logger.isDebugEnabled() guard + logger.debug("rElements.get(" + (hits * numPartitionsPerDataElement + partNum) + ") = " + + rElements.get(hits *
numPartitionsPerDataElement + partNum).toString(2) + " bitLength = " + + rElements.get(hits * numPartitionsPerDataElement + partNum).bitLength() + " val = " + + rElements.get(hits * numPartitionsPerDataElement + partNum)); + logger.debug("colNum = " + (hits * numPartitionsPerDataElement + partNum) + " partNum = " + partNum + " part = " + part); + + part = part.shiftRight(selectorIndex * dataPartitionBitSize); + parts.add(part); + + logger.debug("partNum = " + partNum + " part = " + part.intValue()); + // the hit counts as empty only if every one of its parts is zero + zeroElement = zeroElement && part.equals(BigInteger.ZERO); + } + + logger.debug("parts.size() = " + parts.size()); + + if (!zeroElement) + { + // Convert the extracted parts to a QueryResponseJSON object, based on queryInfo and the query schema + QueryResponseJSON qrJOSN = QueryUtils.extractQueryResponseJSON(queryInfo, qSchema, parts); + qrJOSN.setMapping(selectorName, selector); + logger.debug("selector = " + selector + " qrJOSN = " + qrJOSN.getJSONString()); + + // Add the hit for this selector - if we are using embedded selectors, check to make sure + // that the hit's embedded selector in qrJOSN and the one in embedSelectorMap (keyed by selectorIndex) match; NOTE(review): a missing entry (null from get) would throw NullPointerException + boolean addHit = true; + if (queryInfo.getEmbedSelector()) + { + if (!(embedSelectorMap.get(selectorIndex)).equals(qrJOSN.getValue(QueryResponseJSON.SELECTOR))) + { + addHit = false; + logger.debug("qrJOSN embedded selector = " + qrJOSN.getValue(QueryResponseJSON.SELECTOR) + " != original embedded selector = " + + embedSelectorMap.get(selectorIndex)); + } + } + if (addHit) + { + List<QueryResponseJSON> selectorHitList = resultMap.get(selector); + selectorHitList.add(qrJOSN); + resultMap.put(selector, selectorHitList); + + // Record the selector on the hit itself (the same object already added to the hit list above) + qrJOSN.setMapping(QueryResponseJSON.SELECTOR, selector); + } + } + + 
++selectorIndex; + } + } + + return resultMap; + } +}