diff --git a/eos/core/tests.py b/eos/core/tests.py index 09a9939..735b6a1 100644 --- a/eos/core/tests.py +++ b/eos/core/tests.py @@ -25,6 +25,13 @@ class EosTestCase: def setUpClass(cls): pass + def assertTrue(self, a): + if is_python: + self.impl.assertTrue(a) + else: + if not a: + raise Error('Assertion failed: ' + str(a) + ' not True') + def assertEqual(self, a, b): if is_python: self.impl.assertEqual(a, b) diff --git a/eos/psgjjr/bitstream.py b/eos/psgjjr/bitstream.py index 3f7f03c..7983eed 100644 --- a/eos/psgjjr/bitstream.py +++ b/eos/psgjjr/bitstream.py @@ -32,21 +32,53 @@ class BitStream(EosObject): self.ptr = ptr self.remaining = self.nbits - self.ptr - def read(self, nbits): + def read(self, nbits=None): # 11000110110 # ^---- + if nbits is None: + nbits = self.remaining + if nbits > self.remaining: + nbits = self.remaining + val = (self.impl >> (self.remaining - nbits)) & ((ONE << nbits) - ONE) self.ptr += nbits self.remaining -= nbits return val - def write(self, bits): + def write(self, bits, nbits=None): # 11 0100110 # 10010 # ^---- - self.impl = ((self.impl >> self.remaining) << (self.remaining + bits.nbits())) | (bits << self.remaining) | (self.impl & ((ONE << self.remaining) - 1)) - self.ptr += bits.nbits() - self.nbits += bits.nbits() + if nbits is None: + nbits = bits.nbits() + + self.impl = ((self.impl >> self.remaining) << (self.remaining + nbits)) | (bits << self.remaining) | (self.impl & ((ONE << self.remaining) - 1)) + self.ptr += nbits + self.nbits += nbits + + # Make the size of this BitStream a multiple of the block_size + def multiple_of(self, block_size): + if self.nbits % block_size != 0: + self.nbits += (block_size - (self.nbits % block_size)) + return self # For convenient chaining + + def map(self, func, block_size): + if self.nbits % block_size != 0: + raise Exception('The size of the BitStream must be a multiple of block_size') + + self.seek(0) + result = [] + while self.remaining > 0: + result.append(func(self.read(block_size))) + return result + + @classmethod + def unmap(cls, value, func, block_size): + bs = cls() + for x in value: + bs.write(func(x), block_size) + bs.seek(0) + return bs def serialise(self): return self.impl diff --git a/eos/psgjjr/crypto.py b/eos/psgjjr/crypto.py index 2bdf8ff..ab962bb 100644 --- a/eos/psgjjr/crypto.py +++ b/eos/psgjjr/crypto.py @@ -56,13 +56,13 @@ class EGPrivateKey(EmbeddedObject): # HAC 8.17 @staticmethod - def generate(): + def generate(group=DEFAULT_GROUP): # Choose an element 1 <= x <= p - 2 - x = BigInt.crypto_random(ONE, DEFAULT_GROUP.p - TWO) + x = BigInt.crypto_random(ONE, group.p - TWO) # Calculate the public key as G^x - X = pow(DEFAULT_GROUP.g, x, DEFAULT_GROUP.p) + X = pow(group.g, x, group.p) - pk = EGPublicKey(group=DEFAULT_GROUP, X=X) + pk = EGPublicKey(group=group, X=X) sk = EGPrivateKey(public_key=pk, x=x) return sk diff --git a/eos/psgjjr/tests.py b/eos/psgjjr/tests.py index 29950a1..4e06f83 100644 --- a/eos/psgjjr/tests.py +++ b/eos/psgjjr/tests.py @@ -27,6 +27,17 @@ class EGTestCase(EosTestCase): ct = sk.public_key.encrypt(pt) m = sk.decrypt(ct) self.assertEqualJSON(pt, m) + + def test_eg_block(self): + test_group = CyclicGroup(p=BigInt('11'), g=BigInt('2')) + pt = BigInt('11010010011111010100101', 2) + sk = EGPrivateKey.generate(test_group) + ct = BitStream(pt).multiple_of(test_group.p.nbits() - 1).map(sk.public_key.encrypt, test_group.p.nbits() - 1) + for i in range(len(ct)): + self.assertTrue(ct[i].gamma < test_group.p) + self.assertTrue(ct[i].delta < test_group.p) + m = BitStream.unmap(ct, sk.decrypt, test_group.p.nbits() - 1).read() + self.assertEqualJSON(pt, m) class BitStreamTestCase(EosTestCase): def test_bitstream(self): @@ -47,3 +58,10 @@ class BitStreamTestCase(EosTestCase): self.assertEqual(bs.read(4), 0b1101) self.assertEqual(bs.read(4), 0b0110) self.assertEqual(bs.read(2), 0b11) + + def test_bitstream_map(self): + bs = BitStream(BigInt('100101011011', 2)) + result = bs.map(lambda x: x, 4) + expect = [0b1001, 0b0101, 0b1011] + for i in range(len(expect)): + self.assertEqual(result[i], expect[i])