diff --git a/pir/src/main/java/dk/au/pir/Driver.java b/pir/src/main/java/dk/au/pir/Driver.java index 3753e3a..08eef7a 100644 --- a/pir/src/main/java/dk/au/pir/Driver.java +++ b/pir/src/main/java/dk/au/pir/Driver.java @@ -1,119 +1,133 @@ package dk.au.pir; import dk.au.pir.databases.Database; -import dk.au.pir.databases.MemoryDatabase; +import dk.au.pir.databases.FakeDatabase; import dk.au.pir.profilers.Profiler; -import dk.au.pir.protocols.evenSimpler.EvenSimplerClient; -import dk.au.pir.protocols.evenSimpler.EvenSimplerServer; import dk.au.pir.protocols.interpoly.InterPolyClient; import dk.au.pir.protocols.interpoly.InterPolyServer; -import dk.au.pir.protocols.simple.SimpleClient; -import dk.au.pir.protocols.simple.SimpleServer; +import dk.au.pir.protocols.stupid.SendAllClient; +import dk.au.pir.protocols.stupid.SendAllServer; +import dk.au.pir.protocols.xor.SqrtXORClient; +import dk.au.pir.protocols.xor.SqrtXORServer; +import dk.au.pir.protocols.xor.XORClient; +import dk.au.pir.protocols.xor.XORServer; import dk.au.pir.settings.PIRSettings; +import java.util.Arrays; + public class Driver { - private static void testEvenSimplerScheme(PIRSettings settings, Database database, Profiler profiler) { - EvenSimplerServer[] servers = new EvenSimplerServer[settings.getNumServers()]; + private static int[] testSendAllScheme(PIRSettings settings, Database database, Profiler profiler) { + SendAllServer[] servers = new SendAllServer[settings.getNumServers()]; for (int i = 0; i < settings.getNumServers(); i++) { - servers[i] = new EvenSimplerServer(database, settings); + servers[i] = new SendAllServer(database, settings); } - EvenSimplerClient client = new EvenSimplerClient(settings, servers, profiler); + SendAllClient client = new SendAllClient(settings, servers, profiler); profiler.start(); - client.receiveBits(0); + int[] res = client.receive(0); profiler.stop(); + return res; } - private static void testSimpleScheme(PIRSettings settings, Database database, Profiler profiler) { - SimpleServer[] servers = new SimpleServer[settings.getNumServers()]; + private static int[] testXORScheme(PIRSettings settings, Database database, Profiler profiler) { + XORServer[] servers = new XORServer[settings.getNumServers()]; for (int i = 0; i < settings.getNumServers(); i++) { - servers[i] = new SimpleServer(database, settings); + servers[i] = new XORServer(database, settings); } - SimpleClient client = new SimpleClient(settings, servers, profiler); + XORClient client = new XORClient(settings, servers, profiler); profiler.start(); - client.receiveBit(0); + int[] res = client.receive(0); profiler.stop(); + return res; } - private static void testSimpleBlockScheme(PIRSettings settings, Database database, Profiler profiler) { - SimpleServer[] servers = new SimpleServer[settings.getNumServers()]; + private static int[] testSqrtXORScheme(PIRSettings settings, Database database, Profiler profiler) { + SqrtXORServer[] servers = new SqrtXORServer[settings.getNumServers()]; for (int i = 0; i < settings.getNumServers(); i++) { - servers[i] = new SimpleServer(database, settings); + servers[i] = new SqrtXORServer(database, settings); } - SimpleClient client = new SimpleClient(settings, servers, profiler); + SqrtXORClient client = new SqrtXORClient(settings, servers, profiler); profiler.start(); - client.receiveBits(0); + int[] res = client.receive(0); profiler.stop(); + return res; } - private static void testGeneralInterPolyScheme(PIRSettings settings, Database database, Profiler profiler) { + private static int[] testInterPolyScheme(PIRSettings settings, Database database, Profiler profiler) throws IllegalArgumentException { InterPolyServer[] servers = new InterPolyServer[settings.getNumServers()]; for (int i = 0; i < settings.getNumServers(); i++) { servers[i] = new InterPolyServer(database, settings); } InterPolyClient client = new InterPolyClient(settings, servers, profiler); profiler.start(); - client.receive(0); + int[] res = client.receive(0); profiler.stop(); + return res; } - private static void testGeneralInterPolyBlockScheme(PIRSettings settings, Database database, Profiler profiler) { - InterPolyServer[] servers = new InterPolyServer[settings.getNumServers()]; - for (int i = 0; i < settings.getNumServers(); i++) { - servers[i] = new InterPolyServer(database, settings); - } - InterPolyClient client = new InterPolyClient(settings, servers, profiler); - profiler.start(); - client.receiveBlock(0); - profiler.stop(); - } - - private static void runTests() { - for (int numServers = 1; numServers <= 16; numServers = numServers*2) { - for (int databaseSize = 2048; databaseSize <= 32_768; databaseSize = databaseSize*2) { - for (int blockSize = 64; blockSize <= 16_384; blockSize = blockSize*2) { - for (int i = 0; i < 5; i++) { - runTest(numServers, databaseSize, blockSize); - } - } - } + private static void runTests(int numServers, int databaseSize, int blockSize) { + PIRSettings settings = new PIRSettings(databaseSize, numServers, blockSize); + for (int i = 0; i < 3; i++) { // TODO: repeat x times to warm-up + runTest(numServers, databaseSize, blockSize, settings); } } - private static void runTest(int numServers, int databaseSize, int blockSize) { - PIRSettings settings = new PIRSettings(databaseSize*blockSize, numServers, blockSize); - int[] x = new int[databaseSize*blockSize]; - for (int i = 0; i < x.length; i++) { - x[i] = (int) (Math.random()*2); // 0 or 1 - } - Database database = new MemoryDatabase(settings, x); + private static void runTest(int numServers, int databaseSize, int blockSize, PIRSettings settings) { + Database database = new FakeDatabase(settings); Profiler profiler = new Profiler(); - profiler.reset(); - testEvenSimplerScheme(settings, database, profiler); - reportResult(numServers, databaseSize, blockSize, profiler, "EvenSimplerScheme"); + try { + profiler.reset(); + testSendAllScheme(settings, database, profiler); + reportResult(numServers, databaseSize, blockSize, profiler, "SendAllScheme"); + } catch (OutOfMemoryError error) { + reportFailure(numServers, databaseSize, blockSize, "oom", "SendAllScheme"); + } if (numServers == 2) { - profiler.reset(); - testSimpleScheme(settings, database, profiler); - reportResult(numServers, databaseSize, blockSize, profiler, "SimpleScheme"); + try { + profiler.reset(); + testXORScheme(settings, database, profiler); + reportResult(numServers, databaseSize, blockSize, profiler, "XORScheme"); + } catch (OutOfMemoryError error) { + reportFailure(numServers, databaseSize, blockSize, "oom", "XORScheme"); + } - profiler.reset(); - testSimpleBlockScheme(settings, database, profiler); - reportResult(numServers, databaseSize, blockSize, profiler, "SimpleBlockScheme"); + try { + profiler.reset(); + testSqrtXORScheme(settings, database, profiler); + reportResult(numServers, databaseSize, blockSize, profiler, "SqrtXORScheme"); + } catch (OutOfMemoryError error) { + reportFailure(numServers, databaseSize, blockSize, "oom", "SqrtXORScheme"); + } } - if (settings.getS() != 0 && numServers != 1) { - profiler.reset(); - testGeneralInterPolyScheme(settings, database, profiler); - reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyScheme"); + try { + boolean interPolySchemeShouldFuckOff = true; + if (numServers != 1 && !interPolySchemeShouldFuckOff) { + try { + profiler.reset(); + testInterPolyScheme(settings, database, profiler); + reportResult(numServers, databaseSize, blockSize, profiler, "InterPolyScheme"); + } catch (OutOfMemoryError error) { + reportFailure(numServers, databaseSize, blockSize, "oom", "InterPolyScheme"); + } + } + } catch (IllegalArgumentException ignored) { - profiler.reset(); - testGeneralInterPolyBlockScheme(settings, database, profiler); - reportResult(numServers, databaseSize, blockSize, profiler, "GeneralInterPolyBlockScheme"); } } + private static void reportFailure(int numServers, int databaseSize, int blockSize, String msg, String protocolName) { + System.out.println( + numServers + " " + + databaseSize + " " + + blockSize + " " + + protocolName + " " + + "error:" + msg + ); + } + private static void reportResult(int numServers, int databaseSize, int blockSize, Profiler profiler, String protocolName) { System.out.println( numServers + " " + @@ -126,7 +140,19 @@ public class Driver { ); } + private static void testSanity() { + PIRSettings settings = new PIRSettings(64, 2, 4); + Database database = new FakeDatabase(settings); + Profiler profiler = new Profiler(); + + System.out.println(Arrays.toString(testSendAllScheme(settings, database, profiler))); + System.out.println(Arrays.toString(testXORScheme(settings, database, profiler))); + System.out.println(Arrays.toString(testSqrtXORScheme(settings, database, profiler))); + System.out.println(Arrays.toString(testInterPolyScheme(settings, database, profiler))); + } + public static void main(String[] args) { - runTests(); + //testSanity(); + runTests(Integer.parseInt(args[1]), Integer.parseInt(args[2]), Integer.parseInt(args[3])); } } diff --git a/pir/src/main/java/dk/au/pir/databases/Database.java b/pir/src/main/java/dk/au/pir/databases/Database.java index 5f4f571..a12dc13 100644 --- a/pir/src/main/java/dk/au/pir/databases/Database.java +++ b/pir/src/main/java/dk/au/pir/databases/Database.java @@ -1,5 +1,5 @@ package dk.au.pir.databases; -public interface Database { - public int[] getX(); +public interface Database extends Iterable { + public int get(long index); } diff --git a/pir/src/main/java/dk/au/pir/databases/DiskDatabase.java b/pir/src/main/java/dk/au/pir/databases/DiskDatabase.java new file mode 100644 index 0000000..a5c8c43 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/databases/DiskDatabase.java @@ -0,0 +1,64 @@ +package dk.au.pir.databases; + +import dk.au.pir.settings.PIRSettings; +import dk.au.pir.utils.ArrayUtils; + +import java.io.*; +import java.util.Iterator; + +public class DiskDatabase implements Database { + private PIRSettings settings; + private String fileName; + + public DiskDatabase(PIRSettings settings, String fileName) { + this.settings = settings; + this.fileName = fileName; + } + + @Override + public int get(long index) { + return 0; + } + + private BufferedInputStream getBufferedInputStream() { + try { + File file = new File(this.fileName); + FileInputStream fileInputStream = new FileInputStream(file); + return new BufferedInputStream(fileInputStream); + } catch (FileNotFoundException ignored) { + return null; // yeah, fuck that + } + } + + @Override + public Iterator iterator() { + return new Iterator() { + int blockSize = settings.getBlocksize(); + byte[] cbuf = new byte[blockSize]; + long hasRead = 0; + BufferedInputStream bufferedInputStream = null; + + @Override + public boolean hasNext() { + return hasRead < settings.getDatabaseSize(); + } + + @Override + public Byte[] next() { + try { + if (!hasNext()) { + return null; + } + while (bufferedInputStream == null || bufferedInputStream.read(cbuf, 0, blockSize) == -1) { + System.out.println("Resetting bufferedInputStream"); + bufferedInputStream = getBufferedInputStream(); + } + hasRead += blockSize; + return ArrayUtils.toWrapper(cbuf); + } catch (IOException error) { + return null; + } + } + }; + } +} diff --git a/pir/src/main/java/dk/au/pir/databases/FakeDatabase.java b/pir/src/main/java/dk/au/pir/databases/FakeDatabase.java new file mode 100644 index 0000000..5c98424 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/databases/FakeDatabase.java @@ -0,0 +1,20 @@ +package dk.au.pir.databases; + +import dk.au.pir.settings.PIRSettings; + +import java.util.Iterator; + +public class FakeDatabase implements Database { + public FakeDatabase(PIRSettings settings) { + } + + @Override + public int get(long index) { + return (int) (index % 2); + } + + @Override + public Iterator iterator() { + return null; + } +} diff --git a/pir/src/main/java/dk/au/pir/databases/MemoryDatabase.java b/pir/src/main/java/dk/au/pir/databases/MemoryDatabase.java deleted file mode 100644 index 7340e5e..0000000 --- a/pir/src/main/java/dk/au/pir/databases/MemoryDatabase.java +++ /dev/null @@ -1,15 +0,0 @@ -package dk.au.pir.databases; - -import dk.au.pir.settings.PIRSettings; - -public class MemoryDatabase implements Database{ - private final int[] x; - - public MemoryDatabase(PIRSettings settings, int[] x) { - this.x = x; - } - - public int[] getX() { - return x; - } -} diff --git a/pir/src/main/java/dk/au/pir/profilers/Profiler.java b/pir/src/main/java/dk/au/pir/profilers/Profiler.java index cedd575..7314dc6 100644 --- a/pir/src/main/java/dk/au/pir/profilers/Profiler.java +++ b/pir/src/main/java/dk/au/pir/profilers/Profiler.java @@ -4,8 +4,8 @@ import dk.au.pir.utils.FieldElement; import dk.au.pir.utils.MathUtils; public class Profiler { - private int sent; - private int received; + private long sent; + private long received; private long startTime; private long stopTime; @@ -48,6 +48,23 @@ public class Profiler { return numbersArrays; } + public boolean clientSend(boolean bool) { + this.sent += 1; + return bool; + } + + public boolean[] clientSend(boolean[] bools) { + this.sent += bools.length; + return bools; + } + + public boolean[][] clientSend(boolean[][] boolsArray) { + for (boolean[] bools: boolsArray) { + clientSend(bools); + } + return boolsArray; + } + public FieldElement[] clientSend(FieldElement[] elements) { for (FieldElement element : elements) { this.sent += element.getValue().bitLength(); @@ -55,11 +72,15 @@ public class Profiler { return elements; } - public FieldElement[][] clientSend(FieldElement[][] elements) { - for (FieldElement[] fe : elements) { + public FieldElement[][] clientSend(FieldElement[][] elementsArray) { + for (FieldElement[] fe : elementsArray) { this.clientSend(fe); } - return elements; + return elementsArray; + } + + public void addClientReceived(long bits) { + this.received += bits; } public int clientReceive(int number) { @@ -86,11 +107,11 @@ public class Profiler { return elements; } - public int getSent() { + public long getSent() { return this.sent; } - public int getReceived() { + public long getReceived() { return this.received; } @@ -98,10 +119,10 @@ public class Profiler { return this.stopTime - this.startTime; } - public int log2(int n) { + private long log2(long n) { if (n == 0) { return 1; // technically incorrect but for the sake of profiling, a 0-bit requires 1 bit of space } - return Integer.SIZE - Integer.numberOfLeadingZeros(n); + return Long.SIZE - Long.numberOfLeadingZeros(n); } } diff --git a/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockClient.java b/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockClient.java deleted file mode 100644 index 7e11bf0..0000000 --- a/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockClient.java +++ /dev/null @@ -1,84 +0,0 @@ -package dk.au.pir.protocols.balancedBlockScheme; - -import dk.au.pir.databases.Database; -import dk.au.pir.databases.MemoryDatabase; -import dk.au.pir.profilers.Profiler; -import dk.au.pir.protocols.simple.SimpleClient; -import dk.au.pir.protocols.simple.SimpleServer; -import dk.au.pir.settings.PIRSettings; - -import java.util.Arrays; -import java.util.Random; - -public class balancedBlockClient { - - - private final PIRSettings settings; - private final balancedBlockServer[] servers; - private final int sqrtSize; - private Profiler profiler; - - public balancedBlockClient(PIRSettings settings, balancedBlockServer[] servers, Profiler profiler) { - this.settings = settings; - this.servers = servers; - this.profiler = profiler; - this.sqrtSize = (int) Math.ceil(Math.sqrt(settings.getDatabaseSize())); - } - - public int[] selectIndexes(int n) { - int[] indexes = new int[n]; - Random rand = new Random(); - for (int i=0; i < n; i++) { - indexes[i] = rand.nextInt(2); - } - return indexes; - } - - public int receiveBit(int index) { - /** - * PLAN: - * Divide n into sqrt(n) - * Compute which index we want find this within a block - * Send block - */ - - int[] S1 = selectIndexes(this.sqrtSize); - int[] S2 = S1.clone(); - - - - int impBlock = (int) Math.floor(index/this.sqrtSize); - System.out.println("ImpBlock: " + impBlock); - if (S1[index % this.sqrtSize] == 1) { - S2[index % this.sqrtSize] = 0; // Remove the index, if it's contained in S. - } else { - S2[index % this.sqrtSize] = 1; - } - - System.out.println("S1: " + Arrays.toString(S1)); - System.out.println("S2: " + Arrays.toString(S2)); - - - int[] resBit1 = this.servers[0].computeBit(S1); - int[] resBit2 = this.servers[1].computeBit(S2); - - - - return ((resBit1[impBlock] + resBit2[impBlock]) % 2); - - } - - public static void main(String[] args) { - PIRSettings settings = new PIRSettings(16, 2, 1); - balancedBlockServer[] servers = new balancedBlockServer[settings.getNumServers()]; - - Database database = new MemoryDatabase(settings, new int[] {0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0}); - - for (int i = 0; i < settings.getNumServers(); i++) { - servers[i] = new balancedBlockServer(database, settings); - } - balancedBlockClient client = new balancedBlockClient(settings, servers, null); - System.out.println(client.receiveBit(11)); - } -} - diff --git a/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockServer.java b/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockServer.java deleted file mode 100644 index a4be041..0000000 --- a/pir/src/main/java/dk/au/pir/protocols/balancedBlockScheme/balancedBlockServer.java +++ /dev/null @@ -1,54 +0,0 @@ -package dk.au.pir.protocols.balancedBlockScheme; - -import dk.au.pir.databases.Database; -import dk.au.pir.settings.PIRSettings; - -import java.util.Arrays; - -public class balancedBlockServer { - - - private final Database database; - private final PIRSettings settings; - private final int sqrtSize; - - public balancedBlockServer(Database database, PIRSettings settings) { - this.database = database; - this.settings = settings; - this.sqrtSize = (int) Math.ceil(Math.sqrt(settings.getDatabaseSize())); - } - - public int[] computeBit(int[] indexes) { - int[] db = database.getX(); - - /* - Divide n in the sqrt(n) size chunks - Get sqrt(n) size array from client, which we cycle through sqrt(n) times - We return a sqrt(n) size list of bits. One from each cycle. - */ - - int[] resList = new int[this.sqrtSize]; - - for (int i = 0; i < this.sqrtSize; i++) { - - int tmpRes = 0; - - for (int j = 0; j < this.sqrtSize; j++) { - try { - boolean test = indexes[j] == 1; - if (test) { - System.out.println("Looking at index: " + (j + (this.sqrtSize * i))); - tmpRes = (tmpRes + db[j + (this.sqrtSize * i)]) % 2; - } - - } catch (ArrayIndexOutOfBoundsException e) { - tmpRes = (tmpRes) % 2; - } - } - resList[i] = tmpRes; - } - System.out.println("ResList: " + Arrays.toString(resList)); - - return resList; - } -} diff --git a/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerClient.java b/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerClient.java deleted file mode 100644 index 2dc3cfa..0000000 --- a/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerClient.java +++ /dev/null @@ -1,28 +0,0 @@ -package dk.au.pir.protocols.evenSimpler; - -import dk.au.pir.profilers.Profiler; -import dk.au.pir.settings.PIRSettings; - -public class EvenSimplerClient { - private final PIRSettings settings; - private final EvenSimplerServer[] servers; - private final Profiler profiler; - - public EvenSimplerClient(PIRSettings settings, EvenSimplerServer[] servers, Profiler profiler) { - this.settings = settings; - this.servers = servers; - this.profiler = profiler; - } - - public int receiveBit(int index) { - int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase()); - return data[index]; - } - - public int[] receiveBits(int record) { - int[] res = new int[settings.getBlocksize()]; - int[] data = this.profiler.clientReceive(this.servers[0].giveDatabase()); - System.arraycopy(data, (record * settings.getBlocksize()), res, 0, settings.getBlocksize()); - return res; - } -} diff --git a/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerServer.java b/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerServer.java deleted file mode 100644 index 71856d0..0000000 --- a/pir/src/main/java/dk/au/pir/protocols/evenSimpler/EvenSimplerServer.java +++ /dev/null @@ -1,16 +0,0 @@ -package dk.au.pir.protocols.evenSimpler; - -import dk.au.pir.databases.Database; -import dk.au.pir.settings.PIRSettings; - -public class EvenSimplerServer { - private final Database database; - - public EvenSimplerServer(Database database, PIRSettings settings) { - this.database = database; - } - - public int[] giveDatabase() { - return this.database.getX(); // lol - } -} diff --git a/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyClient.java b/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyClient.java index 1ab9eb3..85d4d0f 100644 --- a/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyClient.java +++ b/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyClient.java @@ -5,28 +5,34 @@ import dk.au.pir.profilers.Profiler; import dk.au.pir.settings.PIRSettings; import dk.au.pir.utils.FieldElement; import dk.au.pir.utils.FieldElementLagrange; - -import java.math.BigInteger; -import java.util.Arrays; - -import static dk.au.pir.utils.ProtocolUtils.printIntArrayArray; +import dk.au.pir.utils.MathUtils; +import dk.au.pir.utils.ProtocolUtils; public class InterPolyClient { private PIRSettings settings; private InterPolyServer[] servers; private final int s; private final BigIntegerField field; - private final int[][] sequences; + private final boolean[][] sequences; private Profiler profiler; - public InterPolyClient(PIRSettings settings, InterPolyServer[] servers, Profiler profiler) { + public InterPolyClient(PIRSettings settings, InterPolyServer[] servers, Profiler profiler) throws IllegalArgumentException { this.settings = settings; this.servers = servers; - this.s = settings.getS(); this.field = settings.getField(); - this.sequences = settings.getSequences(); this.profiler = profiler; + this.s = calculateS(this.settings.getNumServers(), this.settings.getDatabaseSize() * this.settings.getBlocksize()); // TODO: Should be long-multiplication + this.sequences = ProtocolUtils.createSequences(s, this.settings.getNumServers(), this.settings.getDatabaseSize() * this.settings.getBlocksize()); + } + + private int calculateS(int k, int n) throws IllegalArgumentException { + for (int s = k-1; s <= n; s++) { + if (MathUtils.binomial(s, k-1) >= n) { + return s; + } + } + throw new IllegalArgumentException(); } private FieldElement[] getRandomFieldElements() { @@ -48,47 +54,28 @@ public class InterPolyClient { private FieldElement[] getGs(int index, int serverNumber, FieldElement[] random) { FieldElement[] gs = new FieldElement[this.s]; - int[] i = this.sequences[index]; + boolean[] i = this.sequences[index]; for (int l = 0; l < this.s; l++) { - - gs[l] = random[l].multiply(this.field.valueOf(serverNumber)).add(this.field.valueOf(i[l])); + gs[l] = random[l].multiply(this.field.valueOf(serverNumber)); + if (i[l]) { + gs[l].add(this.field.valueOf(1)); + } } return gs; - - } - - private int receiveBit(int index) { - FieldElement[] randoms = this.getRandomFieldElements(); - FieldElement[] Fs = new FieldElement[this.servers.length]; - for (int z = 0; z < this.servers.length; z++) { - Fs[z] = this.profiler.clientReceive(this.servers[z].F(this.profiler.clientSend(this.getGs(index, z+1, randoms)))); - } - FieldElement res = FieldElementLagrange.interpolate(this.field, Fs); - return res.getValue().intValue(); } public int[] receive(int record) { - int[] results = new int[settings.getBlocksize()]; - for (int i = 0; i < settings.getBlocksize(); i++) { - results[i] = this.receiveBit((settings.getBlocksize() * record) + i); - } - return results; - } - - public int[] receiveBlock(int record) { int[] results = new int[settings.getBlocksize()]; FieldElement[][] randoms = this.getRandomFieldElementsBlock(); FieldElement[][] Fs = new FieldElement[this.servers.length][settings.getBlocksize()]; - /** - * 1) Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index - */ + //Compute all the Gs for each server, s.t. the first index should be the blocksize and it should contain all the Gs for the given index for (int z = 0; z < this.servers.length; z++) { FieldElement[][] Gs = new FieldElement[settings.getBlocksize()][this.s]; for (int i = 0; i < settings.getBlocksize(); i++) { Gs[i] = this.getGs(record*settings.getBlocksize() + i, z+1, randoms[i]); } - Fs[z] = profiler.clientReceive(this.servers[z].FBlock(profiler.clientSend(Gs))); + Fs[z] = profiler.clientReceive(this.servers[z].FBlock(profiler.clientSend(Gs), this.s, this.sequences)); } for (int i = 0; i < settings.getBlocksize(); i++) { diff --git a/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyServer.java b/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyServer.java index 1474794..f48f878 100644 --- a/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyServer.java +++ b/pir/src/main/java/dk/au/pir/protocols/interpoly/InterPolyServer.java @@ -17,29 +17,29 @@ public class InterPolyServer { this.field = settings.getField(); } - public FieldElement F(FieldElement[] gs) { + private FieldElement F(FieldElement[] gs, int s, boolean[][] sequences) { FieldElement sum = this.field.valueOf(0); - for (int j = 0; j < this.settings.getDatabaseSize(); j++) { + for (int j = 0; j < this.settings.getDatabaseSize() * this.settings.getBlocksize(); j++) { // TODO: Should be long-multiplcation FieldElement product = this.field.valueOf(1); - for (int l = 0; l < this.settings.getS(); l++) { - if (this.settings.getSequences()[j][l] == 1) { + for (int l = 0; l < s; l++) { + if (sequences[j][l]) { product = product.multiply(gs[l]); //System.out.println("gs: " + gs[l]); } } - sum = sum.add(product.multiply(this.field.valueOf(this.database.getX()[j]))); + sum = sum.add(product.multiply(this.field.valueOf(this.database.get(j)))); } return sum; } - public FieldElement[] FBlock(FieldElement[][] gs) { + public FieldElement[] FBlock(FieldElement[][] gs, int s, boolean[][] sequences) { FieldElement[] sum = new FieldElement[this.settings.getBlocksize()]; for (int i = 0; i < sum.length; i++) { sum[i] = this.field.valueOf(0); } for (int i = 0; i < this.settings.getBlocksize(); i++) { - sum[i] = F(gs[i]); + sum[i] = F(gs[i], s, sequences); } return sum; } diff --git a/pir/src/main/java/dk/au/pir/protocols/simple/SimpleClient.java b/pir/src/main/java/dk/au/pir/protocols/simple/SimpleClient.java deleted file mode 100644 index c0612c2..0000000 --- a/pir/src/main/java/dk/au/pir/protocols/simple/SimpleClient.java +++ /dev/null @@ -1,72 +0,0 @@ -package dk.au.pir.protocols.simple; - -import dk.au.pir.profilers.Profiler; -import dk.au.pir.settings.PIRSettings; - -import java.util.Random; - - -public class SimpleClient { - private final PIRSettings settings; - private final SimpleServer[] servers; - private Profiler profiler; - - public SimpleClient(PIRSettings settings, SimpleServer[] servers, Profiler profiler) { - this.settings = settings; - this.servers = servers; - this.profiler = profiler; - } - - public int[] selectIndexes() { - int[] indexes = new int[settings.getDatabaseSize()]; - Random rand = new Random(); - for (int i=0; i < settings.getDatabaseSize(); i++) { - indexes[i] = rand.nextInt(2); - } - return indexes; - } - - public int receiveBit(int index) { - int[] S1 = selectIndexes(); - int[] S2 = S1.clone(); - - if (S1[index] == 1) { - S2[index] = 0; // Remove the index, if it's contained in S. - } else { - S2[index] = 1; - } - - int resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1))); - int resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2))); - - return ((resBit1 + resBit2) % 2); - } - - public int[] receiveBits(int record) { - int[] result = new int[settings.getBlocksize()]; - - int[][] S1s = new int[settings.getBlocksize()][settings.getDatabaseSize()]; - int[][] S2s = new int[settings.getBlocksize()][settings.getDatabaseSize()]; - - for (int i = 0; i < settings.getBlocksize(); i++) { - S1s[i] = selectIndexes(); - S2s[i] = S1s[i].clone(); - - if (S1s[i][(record*settings.getBlocksize())+i] == 1) { - // Remove the index, if it's contained in S. - S2s[i][(record*settings.getBlocksize())+i] = 0; - } else { - S2s[i][(record*settings.getBlocksize())+i] = 1; - } - } - - int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s))); - int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s))); - - for (int i = 0; i < settings.getBlocksize(); i++) { - result[i] = (resBit1[i] + resBit2[i]) % 2; - } - - return result; - } -} diff --git a/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllClient.java b/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllClient.java new file mode 100644 index 0000000..3407074 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllClient.java @@ -0,0 +1,27 @@ +package dk.au.pir.protocols.stupid; + +import dk.au.pir.databases.Database; +import dk.au.pir.profilers.Profiler; +import dk.au.pir.settings.PIRSettings; + +public class SendAllClient { + private final PIRSettings settings; + private final SendAllServer[] servers; + private final Profiler profiler; + + public SendAllClient(PIRSettings settings, SendAllServer[] servers, Profiler profiler) { + this.settings = settings; + this.servers = servers; + this.profiler = profiler; + } + + public int[] receive(int record) { + int[] res = new int[settings.getBlocksize()]; + Database database = this.servers[0].giveDatabase(); + this.profiler.addClientReceived((long) this.settings.getDatabaseSize() * (long) this.settings.getBlocksize()); + for (int i = 0; i < this.settings.getBlocksize(); i++) { + res[i] = database.get(record * settings.getBlocksize() + i); + } + return res; + } +} diff --git a/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllServer.java b/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllServer.java new file mode 100644 index 0000000..f60bf75 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/protocols/stupid/SendAllServer.java @@ -0,0 +1,16 @@ +package dk.au.pir.protocols.stupid; + +import dk.au.pir.databases.Database; +import dk.au.pir.settings.PIRSettings; + +public class SendAllServer { + private final Database database; + + public SendAllServer(Database database, PIRSettings settings) { + this.database = database; + } + + public Database giveDatabase() { + return this.database; // lol + } +} diff --git a/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORClient.java b/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORClient.java new file mode 100644 index 0000000..88d28f2 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORClient.java @@ -0,0 +1,58 @@ +package dk.au.pir.protocols.xor; + +import dk.au.pir.profilers.Profiler; +import dk.au.pir.settings.PIRSettings; + +import java.util.Random; + +public class SqrtXORClient { + private final PIRSettings settings; + private final SqrtXORServer[] servers; + private final int sqrtSize; + private Profiler profiler; + + public SqrtXORClient(PIRSettings settings, SqrtXORServer[] servers, Profiler profiler) { + this.settings = settings; + this.servers = servers; + this.profiler = profiler; + this.sqrtSize = (int) Math.ceil(Math.sqrt((long) settings.getDatabaseSize() * (long) settings.getBlocksize())); + } + + public boolean[] selectIndexes(int n) { + boolean[] indexes = new boolean[n]; + Random rand = new Random(); + for (int i=0; i < n; i++) { + indexes[i] = rand.nextBoolean(); + } + return indexes; + } + + public int receiveBit(int index) { + /** + * PLAN: + * Divide n into sqrt(n) + * Compute which index we want find this within a block + * Send block + */ + boolean[] S1 = selectIndexes(this.sqrtSize); + boolean[] S2 = S1.clone(); + + int impBlock = (int) Math.floor(index/this.sqrtSize); + S2[index % this.sqrtSize] = !S1[index % this.sqrtSize]; // Remove the index, if it's contained in S. + + int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBit(this.profiler.clientSend(S1))); + int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBit(this.profiler.clientSend(S2))); + + return ((resBit1[impBlock] + resBit2[impBlock]) % 2); + } + + public int[] receive(int record) { + // TODO: This is bad - should merge with above receiveBit-method to send entire array of bits at once (like the simple XORScheme) + int[] result = new int[settings.getBlocksize()]; + for (int i = 0; i < settings.getBlocksize(); i++) { + result[i] = this.receiveBit(record * this.settings.getBlocksize() + i); + } + return result; + } +} + diff --git a/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORServer.java b/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORServer.java new file mode 100644 index 0000000..cda447c --- /dev/null +++ b/pir/src/main/java/dk/au/pir/protocols/xor/SqrtXORServer.java @@ -0,0 +1,34 @@ +package dk.au.pir.protocols.xor; + +import dk.au.pir.databases.Database; +import dk.au.pir.settings.PIRSettings; + +public class SqrtXORServer { + private final Database database; + private final PIRSettings settings; + private final int sqrtSize; + + public SqrtXORServer(Database database, PIRSettings settings) { + this.database = database; + this.settings = settings; + this.sqrtSize = (int) Math.ceil(Math.sqrt((long) settings.getDatabaseSize() * (long) settings.getBlocksize())); + } + + public int[] computeBit(boolean[] indexes) { + int[] resList = new int[this.sqrtSize]; + for (int i = 0; i < this.sqrtSize; i++) { + int tmpRes = 0; + for (int j = 0; j < this.sqrtSize; j++) { + try { + if (indexes[j]) { + tmpRes = (tmpRes + this.database.get(j + (this.sqrtSize * i))) % 2; + } + } catch (ArrayIndexOutOfBoundsException ignored) { + + } + } + resList[i] = tmpRes; + } + return resList; + } +} diff --git a/pir/src/main/java/dk/au/pir/protocols/xor/XORClient.java b/pir/src/main/java/dk/au/pir/protocols/xor/XORClient.java new file mode 100644 index 0000000..835196d --- /dev/null +++ b/pir/src/main/java/dk/au/pir/protocols/xor/XORClient.java @@ -0,0 +1,51 @@ +package dk.au.pir.protocols.xor; + +import dk.au.pir.profilers.Profiler; +import dk.au.pir.settings.PIRSettings; + +import java.util.Random; + + +public class XORClient { + private final PIRSettings settings; + private final XORServer[] servers; + private Profiler profiler; + + public XORClient(PIRSettings settings, XORServer[] servers, Profiler profiler) { + this.settings = settings; + this.servers = servers; + this.profiler = profiler; + } + + private boolean[] selectIndexes() { + boolean[] indexes = new boolean[settings.getDatabaseSize() * settings.getBlocksize()]; // TODO: should be long-multiplication + Random rand = new Random(); + for (int i=0; i < settings.getDatabaseSize() * settings.getBlocksize(); i++) { + indexes[i] = rand.nextBoolean(); + } + return indexes; + } + + public int[] receive(int record) { + int[] result = new int[settings.getBlocksize()]; + + boolean[][] S1s = new boolean[settings.getBlocksize()][settings.getDatabaseSize() * settings.getBlocksize()]; // TODO: Should be long-multiplication + boolean[][] S2s = new boolean[settings.getBlocksize()][settings.getDatabaseSize() * settings.getBlocksize()]; + + for (int i = 0; i < settings.getBlocksize(); i++) { + S1s[i] = selectIndexes(); // TODO + S2s[i] = S1s[i].clone(); + + // Remove the index, if it's contained in S. + S2s[i][(record*settings.getBlocksize())+i] = !S1s[i][(record * settings.getBlocksize()) + i]; + } + + int[] resBit1 = this.profiler.clientReceive(this.servers[0].computeBits(this.profiler.clientSend(S1s))); + int[] resBit2 = this.profiler.clientReceive(this.servers[1].computeBits(this.profiler.clientSend(S2s))); + + for (int i = 0; i < settings.getBlocksize(); i++) { + result[i] = (resBit1[i] + resBit2[i]) % 2; + } + return result; + } +} diff --git a/pir/src/main/java/dk/au/pir/protocols/simple/SimpleServer.java b/pir/src/main/java/dk/au/pir/protocols/xor/XORServer.java similarity index 58% rename from pir/src/main/java/dk/au/pir/protocols/simple/SimpleServer.java rename to pir/src/main/java/dk/au/pir/protocols/xor/XORServer.java index 76b19c3..96d5475 100644 --- a/pir/src/main/java/dk/au/pir/protocols/simple/SimpleServer.java +++ b/pir/src/main/java/dk/au/pir/protocols/xor/XORServer.java @@ -1,27 +1,27 @@ -package dk.au.pir.protocols.simple; +package dk.au.pir.protocols.xor; import dk.au.pir.databases.Database; import dk.au.pir.settings.PIRSettings; -public class SimpleServer { +public class XORServer { private final Database database; private final PIRSettings settings; - public SimpleServer(Database database, PIRSettings settings) { + public XORServer(Database database, PIRSettings settings) { this.database = database; this.settings = settings; } - public int computeBit(int[] indexes) { - int res = database.getX()[indexes[0]]; + private int computeBit(boolean[] indexes) { + int res = 0; for (int i=1; i= n) { - return s; - } - } - throw new IllegalArgumentException(); - } - public int getDatabaseSize() { return databaseSize; } @@ -46,14 +26,6 @@ public class PIRSettings { return numServers; } - public int getS() { - return s; - } - - public int[][] getSequences() { - return sequences; - } - public BigIntegerField getField() { return field; } diff --git a/pir/src/main/java/dk/au/pir/utils/ArrayUtils.java b/pir/src/main/java/dk/au/pir/utils/ArrayUtils.java new file mode 100644 index 0000000..b0346c1 --- /dev/null +++ b/pir/src/main/java/dk/au/pir/utils/ArrayUtils.java @@ -0,0 +1,11 @@ +package dk.au.pir.utils; + +import java.util.Arrays; + +public class ArrayUtils { + public static Byte[] toWrapper(byte[] bytesPrim) { + Byte[] bytes = new Byte[bytesPrim.length]; + Arrays.setAll(bytes, n -> bytesPrim[n]); + return bytes; + } +} diff --git a/pir/src/main/java/dk/au/pir/utils/MathUtils.java b/pir/src/main/java/dk/au/pir/utils/MathUtils.java index 76fca30..33f17f4 100644 --- a/pir/src/main/java/dk/au/pir/utils/MathUtils.java +++ b/pir/src/main/java/dk/au/pir/utils/MathUtils.java @@ -2,10 +2,12 @@ package dk.au.pir.utils; public class MathUtils { public static int binomial(int n, int k) { - if ((n == k) || (k == 0)) { - return 1; - } else { - return binomial(n - 1, k) + binomial(n - 1, k - 1); - } + if (k > n - k) + k = n - k; + + int b = 1; + for (int i=1, m=n; i<=k; i++, m--) + b = b * m / i; + return b; } } diff --git a/pir/src/main/java/dk/au/pir/utils/ProtocolUtils.java b/pir/src/main/java/dk/au/pir/utils/ProtocolUtils.java index 7406dde..e088b90 100644 --- a/pir/src/main/java/dk/au/pir/utils/ProtocolUtils.java +++ b/pir/src/main/java/dk/au/pir/utils/ProtocolUtils.java @@ -1,43 +1,14 @@ package dk.au.pir.utils; -import java.util.*; -import java.util.stream.Collectors; - public class ProtocolUtils { - private static int[] createSequence(int s, int k) { - Random rand = new Random(); - int[] sequence = new int[s]; - int kRemaining = k - 1; - while (kRemaining != 0) { - int rand_idx = rand.nextInt(s); - if (sequence[rand_idx] == 0) { - sequence[rand_idx] = 1; - kRemaining--; + public static boolean[][] createSequences(int s, int k, int n) { + // TODO: Un-hardcode for k!=2 + boolean[][] arrays = new boolean[n][s]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < s; j++) { + arrays[i][j] = i == j; } } - return sequence; - } - - public static int[][] createSequences(int s, int k, int n) { - Set> sequences = new HashSet<>(); - while (sequences.size() < n) { - sequences.add(Arrays.stream(createSequence(s, k)).boxed().collect(Collectors.toList())); - } - List> lists = new ArrayList<>(sequences); - lists.sort((l1, l2) -> { - for (int i = 0; i < l1.size(); i++) { - int equals = l1.get(i).compareTo(l2.get(i)); - if (equals != 0) { - return equals; - } - } - return 0; - }); - int[][] arrays = new int[n][s]; - for (int j = 0; j < n; j++) { - int[] array = lists.get(j).stream().mapToInt(i -> i).toArray(); - arrays[j] = array; - } return arrays; } @@ -48,6 +19,5 @@ public class ProtocolUtils { } System.out.println(""); } - } } diff --git a/pir/test.sh b/pir/test.sh new file mode 100755 index 0000000..b59016f --- /dev/null +++ b/pir/test.sh @@ -0,0 +1,8 @@ +apt update +apt install -y htop tmux openjdk-11-jdk +rm -f ~/results.log +cd classes/ + +tmux \ + new-session 'python3 ../collect.py | tee ~/results.log' \; \ + split-window -h 'htop' \;