diff options
Diffstat (limited to 'pskc/crypto')
-rw-r--r-- | pskc/crypto/aeskw.py | 22 | ||||
-rw-r--r-- | pskc/crypto/tripledeskw.py | 10 |
2 files changed, 18 insertions, 14 deletions
diff --git a/pskc/crypto/aeskw.py b/pskc/crypto/aeskw.py index 24e90b0..eeafed1 100644 --- a/pskc/crypto/aeskw.py +++ b/pskc/crypto/aeskw.py @@ -1,7 +1,7 @@ # aeskw.py - implementation of AES key wrapping # coding: utf-8 # -# Copyright (C) 2014 Arthur de Jong +# Copyright (C) 2014-2015 Arthur de Jong # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -20,6 +20,8 @@ """Implement key wrapping as described in RFC 3394 and RFC 5649.""" +import binascii + from Crypto.Cipher import AES from Crypto.Util.number import bytes_to_long, long_to_bytes from Crypto.Util.strxor import strxor @@ -31,8 +33,8 @@ def _split(value): return value[:8], value[8:] -RFC3394_IV = 'a6a6a6a6a6a6a6a6'.decode('hex') -RFC5649_IV = 'a65959a6'.decode('hex') +RFC3394_IV = binascii.a2b_hex('a6a6a6a6a6a6a6a6') +RFC5649_IV = binascii.a2b_hex('a65959a6') def wrap(plaintext, key, iv=None, pad=None): @@ -54,7 +56,7 @@ def wrap(plaintext, key, iv=None, pad=None): raise EncryptionError('Plaintext length wrong') if mli % 8 != 0 and pad is not False: r = (mli + 7) // 8 - plaintext += ((r * 8) - mli) * '\0' + plaintext += ((r * 8) - mli) * b'\0' if iv is None: if len(plaintext) != mli or pad is True: @@ -63,7 +65,7 @@ def wrap(plaintext, key, iv=None, pad=None): iv = RFC3394_IV encrypt = AES.new(key).encrypt - n = len(plaintext) / 8 + n = len(plaintext) // 8 if n == 1: # RFC 5649 shortcut @@ -76,7 +78,7 @@ def wrap(plaintext, key, iv=None, pad=None): for i in range(n): A, R[i] = _split(encrypt(A + R[i])) A = strxor(A, long_to_bytes(n * j + i + 1, 8)) - return A + ''.join(R) + return A + b''.join(R) def unwrap(ciphertext, key, iv=None, pad=None): @@ -95,7 +97,7 @@ def unwrap(ciphertext, key, iv=None, pad=None): raise DecryptionError('Ciphertext length wrong') decrypt = AES.new(key).decrypt - n = len(ciphertext) / 8 - 1 + n = len(ciphertext) // 8 - 1 if n == 1: A, plaintext = _split(decrypt(ciphertext)) @@ -107,16 +109,16 @@ def unwrap(ciphertext, key, iv=None, pad=None): for i in reversed(range(n)): A = strxor(A, long_to_bytes(n * j + i + 1, 8)) A, R[i] = _split(decrypt(A + R[i])) - plaintext = ''.join(R) + plaintext = b''.join(R) if iv is None: if A == RFC3394_IV and pad is not True: return plaintext elif A[:4] == RFC5649_IV and pad is not False: mli = bytes_to_long(A[4:]) - # check padding length is valid and only contains zeros + # check padding length is valid and plaintext only contains zeros if 8 * (n - 1) < mli <= 8 * n and \ - all(x == '\0' for x in plaintext[mli:]): + plaintext.endswith((len(plaintext) - mli) * b'\0'): return plaintext[:mli] elif A == iv: return plaintext diff --git a/pskc/crypto/tripledeskw.py b/pskc/crypto/tripledeskw.py index 47c93f1..a135ebd 100644 --- a/pskc/crypto/tripledeskw.py +++ b/pskc/crypto/tripledeskw.py @@ -1,7 +1,7 @@ # tripledeskw.py - implementation of Triple DES key wrapping # coding: utf-8 # -# Copyright (C) 2014 Arthur de Jong +# Copyright (C) 2014-2015 Arthur de Jong # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -20,6 +20,8 @@ """Implement Triple DES key wrapping as described in RFC 3217.""" +import binascii + from Crypto import Random from Crypto.Cipher import DES3 from Crypto.Hash import SHA @@ -32,7 +34,7 @@ def _cms_hash(value): return SHA.new(value).digest()[:8] -RFC3217_IV = '4adda22c79e82105'.decode('hex') +RFC3217_IV = binascii.a2b_hex('4adda22c79e82105') def wrap(plaintext, key, iv=None): @@ -48,7 +50,7 @@ def wrap(plaintext, key, iv=None): cipher = DES3.new(key, DES3.MODE_CBC, iv) tmp = iv + cipher.encrypt(plaintext + _cms_hash(plaintext)) cipher = DES3.new(key, DES3.MODE_CBC, RFC3217_IV) - return cipher.encrypt(''.join(reversed(tmp))) + return cipher.encrypt(tmp[::-1]) def unwrap(ciphertext, key): @@ -59,7 +61,7 @@ def unwrap(ciphertext, key): if len(ciphertext) % DES3.block_size != 0: raise DecryptionError('Ciphertext length wrong') cipher = DES3.new(key, DES3.MODE_CBC, RFC3217_IV) - tmp = ''.join(reversed(cipher.decrypt(ciphertext))) + tmp = cipher.decrypt(ciphertext)[::-1] cipher = DES3.new(key, DES3.MODE_CBC, tmp[:8]) tmp = cipher.decrypt(tmp[8:]) if tmp[-8:] == _cms_hash(tmp[:-8]): |