diff --git a/eos/core/bigint/js.py b/eos/core/bigint/js.py index 9bcd8d8..c436a44 100644 --- a/eos/core/bigint/js.py +++ b/eos/core/bigint/js.py @@ -91,6 +91,10 @@ class BigInt(EosObject): def __str__(self): return str(self.impl) + def __int__(self): + # WARNING: This will yield unexpected results for large numbers + return int(str(self.impl)) + def __pow__(self, other, modulo=None): if not isinstance(other, BigInt): other = BigInt(other) diff --git a/eos/core/bigint/python.py b/eos/core/bigint/python.py index cc91b23..88dcae4 100644 --- a/eos/core/bigint/python.py +++ b/eos/core/bigint/python.py @@ -37,7 +37,7 @@ class BigInt(EosObject): return BigInt(self.impl.__pow__(other.impl, modulo.impl)) def nbits(self): - return math.ceil(math.log2(self.impl)) + return math.ceil(math.log2(self.impl)) if self.impl > 0 else 0 def serialise(self): return str(self) diff --git a/eos/psgjjr/bitstream.py b/eos/psgjjr/bitstream.py index 7983eed..9035062 100644 --- a/eos/psgjjr/bitstream.py +++ b/eos/psgjjr/bitstream.py @@ -38,7 +38,7 @@ class BitStream(EosObject): if nbits is None: nbits = self.remaining if nbits > self.remaining: - nbits = self.remaining + raise Exception('Not enough bits to read from BitString') val = (self.impl >> (self.remaining - nbits)) & ((ONE << nbits) - ONE) self.ptr += nbits @@ -51,15 +51,53 @@ class BitStream(EosObject): # ^---- if nbits is None: nbits = bits.nbits() + if nbits < bits.nbits(): + raise Exception('Too many bits to write to BitString') self.impl = ((self.impl >> self.remaining) << (self.remaining + nbits)) | (bits << self.remaining) | (self.impl & ((ONE << self.remaining) - 1)) self.ptr += nbits self.nbits += nbits + def read_string(self): + length = self.read(32) + length = length.__int__() # JS attempts to call this twice if we do it in one line + + if is_python: + ba = bytearray() + for i in range(length): + ba.append(int(self.read(7))) + return ba.decode('ascii') + else: + ba = [] + for i in range(length): + val = self.read(7) + val = val.__int__() + ba.append(val) + return String.fromCharCode(*ba) + + def write_string(self, strg): + self.write(BigInt(len(strg)), 32) # TODO: Arbitrary lengths + + # TODO: Support non-ASCII encodings + if is_python: + ba = strg.encode('ascii') + for i in range(len(strg)): + self.write(BigInt(ba[i]), 7) + else: + for i in range(len(strg)): + self.write(BigInt(strg.charCodeAt(i)), 7) + # Make the size of this BitStream a multiple of the block_size - def multiple_of(self, block_size): + def multiple_of(self, block_size, pad_at_end=False): if self.nbits % block_size != 0: - self.nbits += (block_size - (self.nbits % block_size)) + diff = block_size - (self.nbits % block_size) + if pad_at_end: + # Suitable for structured data + self.seek(self.nbits) + self.write(ZERO, diff) + else: + # Suitable for raw numbers + self.nbits += diff return self # For convenient chaining def map(self, func, block_size): diff --git a/eos/psgjjr/crypto.py b/eos/psgjjr/crypto.py index ab962bb..c3f4e24 100644 --- a/eos/psgjjr/crypto.py +++ b/eos/psgjjr/crypto.py @@ -42,6 +42,13 @@ class EGPublicKey(EmbeddedObject): # HAC 8.18 def encrypt(self, message): + message += ONE # Dodgy hack to allow zeroes + + if message <= ZERO: + raise Exception('Invalid message') + if message >= self.group.p: + raise Exception('Invalid message') + # Choose an element 1 <= k <= p - 2 k = BigInt.crypto_random(ONE, self.group.p - TWO) @@ -76,7 +83,8 @@ class EGPrivateKey(EmbeddedObject): gamma_inv = pow(ciphertext.gamma, self.public_key.group.p - ONE - self.x, self.public_key.group.p) - return (gamma_inv * ciphertext.delta) % self.public_key.group.p + pt = (gamma_inv * ciphertext.delta) % self.public_key.group.p + return pt - ONE class EGCiphertext(EmbeddedObject): public_key = EmbeddedObjectField(EGPublicKey) diff --git a/eos/psgjjr/tests.py b/eos/psgjjr/tests.py index 4e06f83..6fb43d6 100644 --- a/eos/psgjjr/tests.py +++ b/eos/psgjjr/tests.py @@ -27,17 +27,6 @@ class EGTestCase(EosTestCase): ct = sk.public_key.encrypt(pt) m = sk.decrypt(ct) self.assertEqualJSON(pt, m) - - def test_eg_block(self): - test_group = CyclicGroup(p=BigInt('11'), g=BigInt('2')) - pt = BigInt('11010010011111010100101', 2) - sk = EGPrivateKey.generate(test_group) - ct = BitStream(pt).multiple_of(test_group.p.nbits() - 1).map(sk.public_key.encrypt, test_group.p.nbits() - 1) - for i in range(len(ct)): - self.assertTrue(ct[i].gamma < test_group.p) - self.assertTrue(ct[i].delta < test_group.p) - m = BitStream.unmap(ct, sk.decrypt, test_group.p.nbits() - 1).read() - self.assertEqualJSON(pt, m) class BitStreamTestCase(EosTestCase): def test_bitstream(self): @@ -65,3 +54,47 @@ class BitStreamTestCase(EosTestCase): expect = [0b1001, 0b0101, 0b1011] for i in range(len(expect)): self.assertEqual(result[i], expect[i]) + + def test_strings(self): + bs = BitStream() + bs.write_string('Hello World!') + bs.seek(0) + self.assertEqual(bs.read(32), len('Hello World!')) + bs.seek(0) + self.assertEqual(bs.read_string(), 'Hello World!') + +class BlockEGTestCase(EosTestCase): + @classmethod + def setUpClass(cls): + class Person(TopLevelObject): + name = StringField() + address = StringField(default=None) + def say_hi(self): + return 'Hello! My name is ' + self.name + + cls.Person = Person + + #cls.test_group = CyclicGroup(p=BigInt('11'), g=BigInt('2')) + cls.test_group = CyclicGroup(p=BigInt('283'), g=BigInt('60')) + cls.sk = EGPrivateKey.generate(cls.test_group) + + def test_basic(self): + pt = BigInt('11010010011111010100101', 2) + ct = BitStream(pt).multiple_of(self.test_group.p.nbits() - 1).map(self.sk.public_key.encrypt, self.test_group.p.nbits() - 1) + for i in range(len(ct)): + self.assertTrue(ct[i].gamma < self.test_group.p) + self.assertTrue(ct[i].delta < self.test_group.p) + m = BitStream.unmap(ct, self.sk.decrypt, self.test_group.p.nbits() - 1).read() + self.assertEqualJSON(pt, m) + + def test_object(self): + obj = self.Person(name='John Smith') + pt = EosObject.to_json(EosObject.serialise_and_wrap(obj)) + bs = BitStream() + bs.write_string(pt) + bs.multiple_of(self.test_group.p.nbits() - 1, True) + ct = bs.map(self.sk.public_key.encrypt, self.test_group.p.nbits() - 1) + bs2 = BitStream.unmap(ct, self.sk.decrypt, self.test_group.p.nbits() - 1) + m = bs2.read_string() + obj2 = EosObject.deserialise_and_unwrap(EosObject.from_json(m)) + self.assertEqualJSON(obj, obj2)