Implement easy block encryption/decryption

This commit is contained in:
Yingtong Li 2017-09-26 21:47:32 +10:00
parent 2c781ab778
commit 00ac2f96aa
Signed by: RunasSudo
GPG Key ID: 7234E476BF21C61A
4 changed files with 66 additions and 9 deletions

View File

@ -25,6 +25,13 @@ class EosTestCase:
def setUpClass(cls): def setUpClass(cls):
pass pass
def assertTrue(self, a):
if is_python:
self.impl.assertTrue(a)
else:
if not a:
raise Error('Assertion failed: ' + str(a) + ' not True')
def assertEqual(self, a, b): def assertEqual(self, a, b):
if is_python: if is_python:
self.impl.assertEqual(a, b) self.impl.assertEqual(a, b)

View File

@ -32,21 +32,53 @@ class BitStream(EosObject):
self.ptr = ptr self.ptr = ptr
self.remaining = self.nbits - self.ptr self.remaining = self.nbits - self.ptr
def read(self, nbits): def read(self, nbits=None):
# 11000110110 # 11000110110
# ^---- # ^----
if nbits is None:
nbits = self.remaining
if nbits > self.remaining:
nbits = self.remaining
val = (self.impl >> (self.remaining - nbits)) & ((ONE << nbits) - ONE) val = (self.impl >> (self.remaining - nbits)) & ((ONE << nbits) - ONE)
self.ptr += nbits self.ptr += nbits
self.remaining -= nbits self.remaining -= nbits
return val return val
def write(self, bits): def write(self, bits, nbits=None):
# 11 0100110 # 11 0100110
# 10010 # 10010
# ^---- # ^----
self.impl = ((self.impl >> self.remaining) << (self.remaining + bits.nbits())) | (bits << self.remaining) | (self.impl & ((ONE << self.remaining) - 1)) if nbits is None:
self.ptr += bits.nbits() nbits = bits.nbits()
self.nbits += bits.nbits()
self.impl = ((self.impl >> self.remaining) << (self.remaining + nbits)) | (bits << self.remaining) | (self.impl & ((ONE << self.remaining) - 1))
self.ptr += nbits
self.nbits += nbits
# Make the size of this BitStream a multiple of the block_size
def multiple_of(self, block_size):
if self.nbits % block_size != 0:
self.nbits += (block_size - (self.nbits % block_size))
return self # For convenient chaining
def map(self, func, block_size):
if self.nbits % block_size != 0:
raise Exception('The size of the BitStream must be a multiple of block_size')
self.seek(0)
result = []
while self.remaining > 0:
result.append(func(self.read(block_size)))
return result
@classmethod
def unmap(cls, value, func, block_size):
bs = cls()
for x in value:
bs.write(func(x), block_size)
bs.seek(0)
return bs
def serialise(self): def serialise(self):
return self.impl return self.impl

View File

@ -56,13 +56,13 @@ class EGPrivateKey(EmbeddedObject):
# HAC 8.17 # HAC 8.17
@staticmethod @staticmethod
def generate(): def generate(group=DEFAULT_GROUP):
# Choose an element 1 <= x <= p - 2 # Choose an element 1 <= x <= p - 2
x = BigInt.crypto_random(ONE, DEFAULT_GROUP.p - TWO) x = BigInt.crypto_random(ONE, group.p - TWO)
# Calculate the public key as G^x # Calculate the public key as G^x
X = pow(DEFAULT_GROUP.g, x, DEFAULT_GROUP.p) X = pow(group.g, x, group.p)
pk = EGPublicKey(group=DEFAULT_GROUP, X=X) pk = EGPublicKey(group=group, X=X)
sk = EGPrivateKey(public_key=pk, x=x) sk = EGPrivateKey(public_key=pk, x=x)
return sk return sk

View File

@ -28,6 +28,17 @@ class EGTestCase(EosTestCase):
m = sk.decrypt(ct) m = sk.decrypt(ct)
self.assertEqualJSON(pt, m) self.assertEqualJSON(pt, m)
def test_eg_block(self):
test_group = CyclicGroup(p=BigInt('11'), g=BigInt('2'))
pt = BigInt('11010010011111010100101', 2)
sk = EGPrivateKey.generate(test_group)
ct = BitStream(pt).multiple_of(test_group.p.nbits() - 1).map(sk.public_key.encrypt, test_group.p.nbits() - 1)
for i in range(len(ct)):
self.assertTrue(ct[i].gamma < test_group.p)
self.assertTrue(ct[i].delta < test_group.p)
m = BitStream.unmap(ct, sk.decrypt, test_group.p.nbits() - 1).read()
self.assertEqualJSON(pt, m)
class BitStreamTestCase(EosTestCase): class BitStreamTestCase(EosTestCase):
def test_bitstream(self): def test_bitstream(self):
bs = BitStream(BigInt('100101011011', 2)) bs = BitStream(BigInt('100101011011', 2))
@ -47,3 +58,10 @@ class BitStreamTestCase(EosTestCase):
self.assertEqual(bs.read(4), 0b1101) self.assertEqual(bs.read(4), 0b1101)
self.assertEqual(bs.read(4), 0b0110) self.assertEqual(bs.read(4), 0b0110)
self.assertEqual(bs.read(2), 0b11) self.assertEqual(bs.read(2), 0b11)
def test_bitstream_map(self):
bs = BitStream(BigInt('100101011011', 2))
result = bs.map(lambda x: x, 4)
expect = [0b1001, 0b0101, 0b1011]
for i in range(len(expect)):
self.assertEqual(result[i], expect[i])