Skip to content

Commit

Permalink
Added Python 3 compatibility.
Browse files Browse the repository at this point in the history
  • Loading branch information
ricmoo committed Jun 20, 2014
1 parent 165a3b0 commit 35b7702
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 56 deletions.
95 changes: 67 additions & 28 deletions pyaes/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,32 @@ def _compact_word(word):
def _string_to_bytes(text):
return list(ord(c) for c in text)

def _bytes_to_string(text):
return "".join(chr(v) for v in text)
def _bytes_to_string(binary):
return "".join(chr(b) for b in binary)

def _concat_list(a, b):
return a + b


# Python 3 compatibility
try:
xrange
except Exception:
xrange = range

# Python 3 supports bytes, which is already an array of integers
def _string_to_bytes(text):
if isinstance(text, bytes):
return text
return [ord(c) for c in text]

# In Python 3, we return bytes
def _bytes_to_string(binary):
return bytes(binary)

# Python 3 cannot concatenate a list onto a bytes, so we bytes-ify it first
def _concat_list(a, b):
return a + bytes(b)


# Based *largely* on the Rijndael implementation
Expand Down Expand Up @@ -118,15 +142,15 @@ def __init__(self, key):
self._Kd = [[0] * 4 for i in xrange(rounds + 1)]

round_key_count = (rounds + 1) * 4
KC = len(key) / 4
KC = len(key) // 4

# Convert the key into ints
tk = [ struct.unpack('>i', key[i:i + 4])[0] for i in xrange(0, len(key), 4) ]

# Copy values into round key arrays
for i in xrange(0, KC):
self._Ke[i / 4][i % 4] = tk[i]
self._Kd[rounds - (i / 4)][i % 4] = tk[i]
self._Ke[i // 4][i % 4] = tk[i]
self._Kd[rounds - (i // 4)][i % 4] = tk[i]

# Key expansion (fips-197 section 5.2)
rconpointer = 0
Expand All @@ -147,23 +171,23 @@ def __init__(self, key):

# Key expansion for 256-bit keys is "slightly different" (fips-197)
else:
for i in xrange(1, KC / 2):
for i in xrange(1, KC // 2):
tk[i] ^= tk[i - 1]
tt = tk[KC / 2 - 1]
tt = tk[KC // 2 - 1]

tk[KC / 2] ^= (self.S[ tt & 0xFF] ^
(self.S[(tt >> 8) & 0xFF] << 8) ^
(self.S[(tt >> 16) & 0xFF] << 16) ^
(self.S[(tt >> 24) & 0xFF] << 24))
tk[KC // 2] ^= (self.S[ tt & 0xFF] ^
(self.S[(tt >> 8) & 0xFF] << 8) ^
(self.S[(tt >> 16) & 0xFF] << 16) ^
(self.S[(tt >> 24) & 0xFF] << 24))

for i in xrange(KC / 2 + 1, KC):
tk[i] ^= tk[i-1]
for i in xrange(KC // 2 + 1, KC):
tk[i] ^= tk[i - 1]

# Copy values into round key arrays
j = 0
while j < KC and t < round_key_count:
self._Ke[t / 4][t % 4] = tk[j]
self._Kd[rounds - (t / 4)][t % 4] = tk[j]
self._Ke[t // 4][t % 4] = tk[j]
self._Kd[rounds - (t // 4)][t % 4] = tk[j]
j += 1
t += 1

Expand Down Expand Up @@ -317,13 +341,15 @@ def encrypt(self, plaintext):
if len(plaintext) != 16:
raise ValueError('plaintext block must be 16 bytes')

return _bytes_to_string(self._aes.encrypt(_string_to_bytes(plaintext)))
plaintext = _string_to_bytes(plaintext)
return _bytes_to_string(self._aes.encrypt(plaintext))

def decrypt(self, ciphertext):
if len(ciphertext) != 16:
raise ValueError('ciphertext block must be 16 bytes')

return _bytes_to_string(self._aes.decrypt(_string_to_bytes(ciphertext)))
ciphertext = _string_to_bytes(ciphertext)
return _bytes_to_string(self._aes.decrypt(ciphertext))



Expand Down Expand Up @@ -362,7 +388,8 @@ def encrypt(self, plaintext):
if len(plaintext) != 16:
raise ValueError('plaintext block must be 16 bytes')

precipherblock = [ (ord(p) ^ l) for (p, l) in zip(plaintext, self._last_cipherblock) ]
plaintext = _string_to_bytes(plaintext)
precipherblock = [ (p ^ l) for (p, l) in zip(plaintext, self._last_cipherblock) ]
self._last_cipherblock = self._aes.encrypt(precipherblock)

return _bytes_to_string(self._last_cipherblock)
Expand All @@ -372,10 +399,10 @@ def decrypt(self, ciphertext):
raise ValueError('ciphertext block must be 16 bytes')

cipherblock = _string_to_bytes(ciphertext)
plaintext = "".join([ chr(p ^ l) for (p, l) in zip(self._aes.decrypt(cipherblock), self._last_cipherblock) ])
plaintext = [ (p ^ l) for (p, l) in zip(self._aes.decrypt(cipherblock), self._last_cipherblock) ]
self._last_cipherblock = cipherblock

return plaintext
return _bytes_to_string(plaintext)



Expand Down Expand Up @@ -412,15 +439,17 @@ def encrypt(self, plaintext):
if len(plaintext) % self._segment_bytes != 0:
raise ValueError('plaintext block must be a multiple of segment_size')

plaintext = _string_to_bytes(plaintext)

# Break block into segments
encrypted = [ ]
for i in xrange(0, len(plaintext), self._segment_bytes):
plaintext_segment = plaintext[i: i + self._segment_bytes]
xor_segment = self._aes.encrypt(self._shift_register)[:len(plaintext_segment)]
cipher_segment = [ (ord(p) ^ x) for (p, x) in zip(plaintext_segment, xor_segment) ]
cipher_segment = [ (p ^ x) for (p, x) in zip(plaintext_segment, xor_segment) ]

# Shift the top bits out and the ciphertext in
self._shift_register = self._shift_register[len(cipher_segment):] + cipher_segment
self._shift_register = _concat_list(self._shift_register[len(cipher_segment):], cipher_segment)

encrypted.extend(cipher_segment)

Expand All @@ -430,15 +459,17 @@ def decrypt(self, ciphertext):
if len(ciphertext) % self._segment_bytes != 0:
raise ValueError('ciphertext block must be a multiple of segment_size')

ciphertext = _string_to_bytes(ciphertext)

# Break block into segments
decrypted = [ ]
for i in xrange(0, len(ciphertext), self._segment_bytes):
cipher_segment = _string_to_bytes(ciphertext[i: i + self._segment_bytes])
cipher_segment = ciphertext[i: i + self._segment_bytes]
xor_segment = self._aes.encrypt(self._shift_register)[:len(cipher_segment)]
plaintext_segment = [ (p ^ x) for (p, x) in zip(cipher_segment, xor_segment) ]

# Shift the top bits out and the ciphertext in
self._shift_register = self._shift_register[len(cipher_segment):] + cipher_segment
self._shift_register = _concat_list(self._shift_register[len(cipher_segment):], cipher_segment)

decrypted.extend(plaintext_segment)

Expand Down Expand Up @@ -476,13 +507,13 @@ def __init__(self, key, iv = None):

def encrypt(self, plaintext):
encrypted = [ ]
for c in plaintext:
for p in _string_to_bytes(plaintext):
if len(self._remaining_block) == 0:
self._remaining_block = self._aes.encrypt(self._last_precipherblock)
self._last_precipherblock = [ ]
precipherbyte = self._remaining_block.pop(0)
self._last_precipherblock.append(precipherbyte)
cipherbyte = ord(c) ^ precipherbyte
cipherbyte = p ^ precipherbyte
encrypted.append(cipherbyte)

return _bytes_to_string(encrypted)
Expand Down Expand Up @@ -536,7 +567,9 @@ def encrypt(self, plaintext):
self._remaining_counter += self._aes.encrypt(self._counter.value)
self._counter.increment()

encrypted = [ (ord(p) ^ c) for (p, c) in zip(plaintext, self._remaining_counter) ]
plaintext = _string_to_bytes(plaintext)

encrypted = [ (p ^ c) for (p, c) in zip(plaintext, self._remaining_counter) ]
self._remaining_counter = self._remaining_counter[len(encrypted):]

return _bytes_to_string(encrypted)
Expand All @@ -547,4 +580,10 @@ def decrypt(self, crypttext):


# Simple lookup table for each mode
AESModesOfOperation = dict(ctr = AESModeOfOperationCTR, cbc = AESModeOfOperationCBC, cfb = AESModeOfOperationCFB, ecb = AESModeOfOperationECB, ofb = AESModeOfOperationOFB)
AESModesOfOperation = dict(
ctr = AESModeOfOperationCTR,
cbc = AESModeOfOperationCBC,
cfb = AESModeOfOperationCFB,
ecb = AESModeOfOperationECB,
ofb = AESModeOfOperationOFB,
)
17 changes: 10 additions & 7 deletions pyaes/blockfeeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
# THE SOFTWARE.


from pyaes.aes import AESBlockModeOfOperation, AESSegmentModeOfOperation, AESStreamModeOfOperation
from pyaes.util import append_PKCS7_padding, strip_PKCS7_padding
from .aes import AESBlockModeOfOperation, AESSegmentModeOfOperation, AESStreamModeOfOperation
from .util import append_PKCS7_padding, strip_PKCS7_padding, to_bufferable


# First we inject three functions to each of the modes of operations
#
Expand Down Expand Up @@ -71,12 +72,14 @@ def _segment_can_consume(self, size):

# CFB can handle a non-segment-sized block at the end using the remaining cipherblock
def _segment_final_encrypt(self, data):
padded = data + (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
faux_padding = (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + to_bufferable(faux_padding)
return self.encrypt(padded)[:len(data)]

# CFB can handle a non-segment-sized block at the end using the remaining cipherblock
def _segment_final_decrypt(self, data):
padded = data + (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
faux_padding = (chr(0) * (self.segment_bytes - (len(data) % self.segment_bytes)))
padded = data + to_bufferable(faux_padding)
return self.decrypt(padded)[:len(data)]

AESSegmentModeOfOperation._can_consume = _segment_can_consume
Expand Down Expand Up @@ -111,7 +114,7 @@ def __init__(self, mode, feed, final):
self._mode = mode
self._feed = feed
self._final = final
self._buffer = ""
self._buffer = to_bufferable("")

def feed(self, data = None):
'''Provide bytes to encrypt (or decrypt), returning any bytes
Expand All @@ -130,10 +133,10 @@ def feed(self, data = None):
self._buffer = None
return result

self._buffer += data
self._buffer += to_bufferable(data)

# We keep 16 bytes around so we can determine padding
result = ''
result = to_bufferable('')
while len(self._buffer) > 16:
can_consume = self._mode._can_consume(len(self._buffer) - 16)
if can_consume == 0: break
Expand Down
27 changes: 25 additions & 2 deletions pyaes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,39 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Why to_bufferable?
# Python 3 is very different from Python 2.x when it comes to strings of text
# and strings of bytes; in Python 3, strings of bytes do not exist, instead to
# represent arbitrary binary data, we must use the "bytes" object. This method
# ensures the object behaves as we need it to.

def to_bufferable(binary):
return binary

def _get_byte(c):
return ord(c)

try:
xrange
except:

def to_bufferable(binary):
if isinstance(binary, bytes):
return binary
return bytes(ord(b) for b in binary)

def _get_byte(c):
return c

def append_PKCS7_padding(data):
pad = 16 - (len(data) % 16)
return data + pad * chr(pad)
return data + to_bufferable(chr(pad) * pad)

def strip_PKCS7_padding(data):
if len(data) % 16 != 0:
raise ValueError("invalid length")

pad = ord(data[-1])
pad = _get_byte(data[-1])

if pad > 16:
raise ValueError("invalid padding byte")
Expand Down
22 changes: 7 additions & 15 deletions tests/test-aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,36 +125,28 @@
count += 1

t0 = time.time()
kenc = "".join(str(kaes.encrypt(p)) for p in plaintext)
kenc = [kaes.encrypt(p) for p in plaintext]
tt_kencrypt += time.time() - t0

t0 = time.time()
enc = "".join(str(aes.encrypt(p)) for p in plaintext)
enc = [aes.encrypt(p) for p in plaintext]
tt_encrypt += time.time() - t0

if kenc != enc:
print(repr((kenc, enc)))
print("Test: mode=%s operation=encrypt key_size=%d text_length=%d trial=%d" % (mode, key_size, len(plaintext), test))
raise Exception('Failed encypt test case')

dec = [ ]
index = 0
for p in plaintext:
dec.append(kenc[index:index + len(p)])
index += len(p)
pt = ''.join(str(p) for p in plaintext)
raise Exception('Failed encypt test case (%s)' % mode)

t0 = time.time()
dt = "".join(str(kaes2.decrypt(k)) for k in dec)
dt1 = [kaes2.decrypt(k) for k in kenc]
tt_kdecrypt += time.time() - t0

t0 = time.time()
dt = "".join(str(aes2.decrypt(k)) for k in dec)
dt2 = [aes2.decrypt(k) for k in kenc]
tt_decrypt += time.time() - t0

if pt != dt:
if plaintext != dt2:
print("Test: mode=%s operation=decrypt key_size=%d text_length=%d trial=%d" % (mode, key_size, len(plaintext), test))
raise Exception('Failed decypt test case')
raise Exception('Failed decypt test case (%s)' % mode)

better = (tt_setup + tt_encrypt + tt_decrypt) / (tt_ksetup + tt_kencrypt + tt_kdecrypt)
print("Mode: %s" % mode)
Expand Down
11 changes: 7 additions & 4 deletions tests/test-blockfeeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,23 @@
import pyaes
from pyaes.blockfeeder import Decrypter, Encrypter

from pyaes.util import to_bufferable


key = os.urandom(32)

plaintext = os.urandom(1000)

for mode_name in pyaes.AESModesOfOperation:
mode = pyaes.AESModesOfOperation[mode_name]
print mode.name
print(mode.name)

kw = dict(key = key)
if mode_name in ('cbc', 'cfb', 'ofb'):
kw['iv'] = os.urandom(16)

encrypter = Encrypter(mode(**kw))
ciphertext = ''
ciphertext = to_bufferable('')

# Feed the encrypter random number of bytes at a time
index = 0
Expand All @@ -55,7 +58,7 @@
ciphertext += encrypter.feed(None)

decrypter = Decrypter(mode(**kw))
decrypted = ''
decrypted = to_bufferable('')

# Feed the decrypter random number of bytes at a time
index = 0
Expand All @@ -68,4 +71,4 @@

passed = decrypted == plaintext
cipher_length = len(ciphertext)
print " cipher-length=%(cipher_length)s passed=%(passed)s" % locals()
print(" cipher-length=%(cipher_length)s passed=%(passed)s" % locals())
Loading

0 comments on commit 35b7702

Please sign in to comment.