mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Various small fixes and clean ups
This commit is contained in:
parent
c1834322b6
commit
53be0708d7
5 changed files with 63 additions and 34 deletions
|
|
@ -23,7 +23,7 @@
|
||||||
from Crypto.Signature.pss import MGF1
|
from Crypto.Signature.pss import MGF1
|
||||||
import Crypto.Hash.SHA1
|
import Crypto.Hash.SHA1
|
||||||
|
|
||||||
from Crypto.Util.py3compat import *
|
from Crypto.Util.py3compat import bord, _copy_bytes
|
||||||
import Crypto.Util.number
|
import Crypto.Util.number
|
||||||
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
||||||
from Crypto.Util.strxor import strxor
|
from Crypto.Util.strxor import strxor
|
||||||
|
|
@ -70,7 +70,7 @@ class PKCS1OAEP_Cipher:
|
||||||
else:
|
else:
|
||||||
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
|
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
|
||||||
|
|
||||||
self._label = bstr(label)
|
self._label = _copy_bytes(None, None, label)
|
||||||
self._randfunc = randfunc
|
self._randfunc = randfunc
|
||||||
|
|
||||||
def can_encrypt(self):
|
def can_encrypt(self):
|
||||||
|
|
@ -102,24 +102,23 @@ class PKCS1OAEP_Cipher:
|
||||||
:raises ValueError:
|
:raises ValueError:
|
||||||
if the message is too long.
|
if the message is too long.
|
||||||
"""
|
"""
|
||||||
# TODO: Verify the key is RSA
|
|
||||||
|
|
||||||
# See 7.1.1 in RFC3447
|
# See 7.1.1 in RFC3447
|
||||||
modBits = Crypto.Util.number.size(self._key.n)
|
modBits = Crypto.Util.number.size(self._key.n)
|
||||||
k = ceil_div(modBits,8) # Convert from bits to bytes
|
k = ceil_div(modBits, 8) # Convert from bits to bytes
|
||||||
hLen = self._hashObj.digest_size
|
hLen = self._hashObj.digest_size
|
||||||
mLen = len(message)
|
mLen = len(message)
|
||||||
|
|
||||||
# Step 1b
|
# Step 1b
|
||||||
ps_len = k-mLen-2*hLen-2
|
ps_len = k - mLen - 2 * hLen - 2
|
||||||
if ps_len<0:
|
if ps_len < 0:
|
||||||
raise ValueError("Plaintext is too long.")
|
raise ValueError("Plaintext is too long.")
|
||||||
# Step 2a
|
# Step 2a
|
||||||
lHash = self._hashObj.new(self._label).digest()
|
lHash = self._hashObj.new(self._label).digest()
|
||||||
# Step 2b
|
# Step 2b
|
||||||
ps = bchr(0x00)*ps_len
|
ps = b'\x00' * ps_len
|
||||||
# Step 2c
|
# Step 2c
|
||||||
db = lHash + ps + bchr(0x01) + message
|
db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
|
||||||
# Step 2d
|
# Step 2d
|
||||||
ros = self._randfunc(hLen)
|
ros = self._randfunc(hLen)
|
||||||
# Step 2e
|
# Step 2e
|
||||||
|
|
@ -131,7 +130,7 @@ class PKCS1OAEP_Cipher:
|
||||||
# Step 2h
|
# Step 2h
|
||||||
maskedSeed = strxor(ros, seedMask)
|
maskedSeed = strxor(ros, seedMask)
|
||||||
# Step 2i
|
# Step 2i
|
||||||
em = bchr(0x00) + maskedSeed + maskedDB
|
em = b'\x00' + maskedSeed + maskedDB
|
||||||
# Step 3a (OS2IP)
|
# Step 3a (OS2IP)
|
||||||
em_int = bytes_to_long(em)
|
em_int = bytes_to_long(em)
|
||||||
# Step 3b (RSAEP)
|
# Step 3b (RSAEP)
|
||||||
|
|
@ -167,7 +166,7 @@ class PKCS1OAEP_Cipher:
|
||||||
if len(ciphertext) != k or k<hLen+2:
|
if len(ciphertext) != k or k<hLen+2:
|
||||||
raise ValueError("Ciphertext with incorrect length.")
|
raise ValueError("Ciphertext with incorrect length.")
|
||||||
# Step 2a (O2SIP)
|
# Step 2a (O2SIP)
|
||||||
ct_int = bytes_to_long(bstr(ciphertext))
|
ct_int = bytes_to_long(ciphertext)
|
||||||
# Step 2b (RSADP)
|
# Step 2b (RSADP)
|
||||||
m_int = self._key._decrypt(ct_int)
|
m_int = self._key._decrypt(ct_int)
|
||||||
# Complete step 2c (I2OSP)
|
# Complete step 2c (I2OSP)
|
||||||
|
|
@ -190,20 +189,20 @@ class PKCS1OAEP_Cipher:
|
||||||
db = strxor(maskedDB, dbMask)
|
db = strxor(maskedDB, dbMask)
|
||||||
# Step 3g
|
# Step 3g
|
||||||
valid = 1
|
valid = 1
|
||||||
one = db[hLen:].find(bchr(0x01))
|
one = db[hLen:].find(b'\x01')
|
||||||
lHash1 = db[:hLen]
|
lHash1 = db[:hLen]
|
||||||
if lHash1!=lHash:
|
if lHash1!=lHash:
|
||||||
valid = 0
|
valid = 0
|
||||||
if one<0:
|
if one<0:
|
||||||
valid = 0
|
valid = 0
|
||||||
if bord(y)!=0:
|
if bord(y) != 0:
|
||||||
valid = 0
|
valid = 0
|
||||||
if not valid:
|
if not valid:
|
||||||
raise ValueError("Incorrect decryption.")
|
raise ValueError("Incorrect decryption.")
|
||||||
# Step 4
|
# Step 4
|
||||||
return db[hLen+one+1:]
|
return db[hLen+one+1:]
|
||||||
|
|
||||||
def new(key, hashAlgo=None, mgfunc=None, label=b(''), randfunc=None):
|
def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
|
||||||
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
|
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
|
||||||
|
|
||||||
:param key:
|
:param key:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@
|
||||||
__all__ = [ 'new', 'PKCS115_Cipher' ]
|
__all__ = [ 'new', 'PKCS115_Cipher' ]
|
||||||
|
|
||||||
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
||||||
from Crypto.Util.py3compat import *
|
from Crypto.Util.py3compat import bord, _copy_bytes
|
||||||
import Crypto.Util.number
|
import Crypto.Util.number
|
||||||
from Crypto import Random
|
from Crypto import Random
|
||||||
|
|
||||||
|
|
@ -79,7 +79,7 @@ class PKCS115_Cipher:
|
||||||
mLen = len(message)
|
mLen = len(message)
|
||||||
|
|
||||||
# Step 1
|
# Step 1
|
||||||
if mLen > k-11:
|
if mLen > k - 11:
|
||||||
raise ValueError("Plaintext is too long.")
|
raise ValueError("Plaintext is too long.")
|
||||||
# Step 2a
|
# Step 2a
|
||||||
ps = []
|
ps = []
|
||||||
|
|
@ -88,10 +88,10 @@ class PKCS115_Cipher:
|
||||||
if bord(new_byte[0]) == 0x00:
|
if bord(new_byte[0]) == 0x00:
|
||||||
continue
|
continue
|
||||||
ps.append(new_byte)
|
ps.append(new_byte)
|
||||||
ps = b("").join(ps)
|
ps = b"".join(ps)
|
||||||
assert(len(ps) == k - mLen - 3)
|
assert(len(ps) == k - mLen - 3)
|
||||||
# Step 2b
|
# Step 2b
|
||||||
em = b('\x00\x02') + ps + bchr(0x00) + bstr(message)
|
em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
|
||||||
# Step 3a (OS2IP)
|
# Step 3a (OS2IP)
|
||||||
em_int = bytes_to_long(em)
|
em_int = bytes_to_long(em)
|
||||||
# Step 3b (RSAEP)
|
# Step 3b (RSAEP)
|
||||||
|
|
@ -164,17 +164,17 @@ class PKCS115_Cipher:
|
||||||
if len(ciphertext) != k:
|
if len(ciphertext) != k:
|
||||||
raise ValueError("Ciphertext with incorrect length.")
|
raise ValueError("Ciphertext with incorrect length.")
|
||||||
# Step 2a (O2SIP)
|
# Step 2a (O2SIP)
|
||||||
ct_int = bytes_to_long(bstr(ciphertext))
|
ct_int = bytes_to_long(ciphertext)
|
||||||
# Step 2b (RSADP)
|
# Step 2b (RSADP)
|
||||||
m_int = self._key._decrypt(ct_int)
|
m_int = self._key._decrypt(ct_int)
|
||||||
# Complete step 2c (I2OSP)
|
# Complete step 2c (I2OSP)
|
||||||
em = long_to_bytes(m_int, k)
|
em = long_to_bytes(m_int, k)
|
||||||
# Step 3
|
# Step 3
|
||||||
sep = em.find(bchr(0x00),2)
|
sep = em.find(b'\x00', 2)
|
||||||
if not em.startswith(b('\x00\x02')) or sep<10:
|
if not em.startswith(b'\x00\x02') or sep < 10:
|
||||||
return sentinel
|
return sentinel
|
||||||
# Step 4
|
# Step 4
|
||||||
return em[sep+1:]
|
return em[sep + 1:]
|
||||||
|
|
||||||
|
|
||||||
def new(key, randfunc=None):
|
def new(key, randfunc=None):
|
||||||
|
|
|
||||||
|
|
@ -160,14 +160,14 @@ HKukWBcq9f/UOmS0oEhai/6g+Uf7VHJdWaeO5LzuvwU=
|
||||||
self.assertEqual(pt,pt2)
|
self.assertEqual(pt,pt2)
|
||||||
|
|
||||||
def testByteArray(self):
|
def testByteArray(self):
|
||||||
pt = b("XER")
|
pt = b"XER"
|
||||||
cipher = PKCS.new(self.key1024)
|
cipher = PKCS.new(self.key1024)
|
||||||
ct = cipher.encrypt(bytearray(pt))
|
ct = cipher.encrypt(bytearray(pt))
|
||||||
pt2 = cipher.decrypt(bytearray(ct), "---")
|
pt2 = cipher.decrypt(bytearray(ct), "---")
|
||||||
self.assertEqual(pt, pt2)
|
self.assertEqual(pt, pt2)
|
||||||
|
|
||||||
def testMemoryview(self):
|
def testMemoryview(self):
|
||||||
pt = b("XER")
|
pt = b"XER"
|
||||||
cipher = PKCS.new(self.key1024)
|
cipher = PKCS.new(self.key1024)
|
||||||
ct = cipher.encrypt(memoryview(bytearray(pt)))
|
ct = cipher.encrypt(memoryview(bytearray(pt)))
|
||||||
pt2 = cipher.decrypt(memoryview(bytearray(ct)), "---")
|
pt2 = cipher.decrypt(memoryview(bytearray(ct)), "---")
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from binascii import unhexlify, hexlify
|
from binascii import unhexlify, hexlify
|
||||||
|
|
||||||
|
from Crypto.Util.py3compat import _memoryview
|
||||||
from Crypto.SelfTest.st_common import list_test_cases
|
from Crypto.SelfTest.st_common import list_test_cases
|
||||||
from Crypto.Util.strxor import strxor, strxor_c
|
from Crypto.Util.strxor import strxor, strxor_c
|
||||||
|
|
||||||
|
|
@ -61,16 +62,27 @@ class StrxorTests(unittest.TestCase):
|
||||||
term2 = unhexlify(b"ff339a83e5cd4cdf564990")
|
term2 = unhexlify(b"ff339a83e5cd4cdf564990")
|
||||||
self.assertRaises(ValueError, strxor, term1, term2)
|
self.assertRaises(ValueError, strxor, term1, term2)
|
||||||
|
|
||||||
def test_bytearray_memoryview(self):
|
def test_bytearray(self):
|
||||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||||
term1_ba = bytearray(term1)
|
term1_ba = bytearray(term1)
|
||||||
term1_mv = memoryview(term1)
|
|
||||||
term2 = unhexlify(b"383d4ba020573314395b")
|
term2 = unhexlify(b"383d4ba020573314395b")
|
||||||
result = unhexlify(b"c70ed123c59a7fcb6f12")
|
result = unhexlify(b"c70ed123c59a7fcb6f12")
|
||||||
|
|
||||||
self.assertEqual(strxor(term1_ba, term2), result)
|
self.assertEqual(strxor(term1_ba, term2), result)
|
||||||
|
|
||||||
|
def test_memoryview(self):
|
||||||
|
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||||
|
term1_mv = memoryview(term1)
|
||||||
|
term2 = unhexlify(b"383d4ba020573314395b")
|
||||||
|
result = unhexlify(b"c70ed123c59a7fcb6f12")
|
||||||
|
|
||||||
self.assertEqual(strxor(term1_mv, term2), result)
|
self.assertEqual(strxor(term1_mv, term2), result)
|
||||||
|
|
||||||
|
import types
|
||||||
|
if _memoryview is types.NoneType:
|
||||||
|
del test_memoryview
|
||||||
|
|
||||||
|
|
||||||
class Strxor_cTests(unittest.TestCase):
|
class Strxor_cTests(unittest.TestCase):
|
||||||
|
|
||||||
def test1(self):
|
def test1(self):
|
||||||
|
|
@ -90,21 +102,32 @@ class Strxor_cTests(unittest.TestCase):
|
||||||
self.assertRaises(ValueError, strxor_c, term1, -1)
|
self.assertRaises(ValueError, strxor_c, term1, -1)
|
||||||
self.assertRaises(ValueError, strxor_c, term1, 256)
|
self.assertRaises(ValueError, strxor_c, term1, 256)
|
||||||
|
|
||||||
def test_bytearray_memoryview(self):
|
def test_bytearray(self):
|
||||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||||
term1_ba = bytearray(term1)
|
term1_ba = bytearray(term1)
|
||||||
term1_mv = memoryview(term1)
|
|
||||||
result = unhexlify(b"be72dbc2a48c0d9e1708")
|
result = unhexlify(b"be72dbc2a48c0d9e1708")
|
||||||
|
|
||||||
self.assertEqual(strxor_c(term1_ba, 65), result)
|
self.assertEqual(strxor_c(term1_ba, 65), result)
|
||||||
|
|
||||||
|
def test_memoryview(self):
|
||||||
|
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||||
|
term1_mv = memoryview(term1)
|
||||||
|
result = unhexlify(b"be72dbc2a48c0d9e1708")
|
||||||
|
|
||||||
self.assertEqual(strxor_c(term1_mv, 65), result)
|
self.assertEqual(strxor_c(term1_mv, 65), result)
|
||||||
|
|
||||||
|
import types
|
||||||
|
if _memoryview is types.NoneType:
|
||||||
|
del test_memoryview
|
||||||
|
|
||||||
|
|
||||||
def get_tests(config={}):
|
def get_tests(config={}):
|
||||||
tests = []
|
tests = []
|
||||||
tests += list_test_cases(StrxorTests)
|
tests += list_test_cases(StrxorTests)
|
||||||
tests += list_test_cases(Strxor_cTests)
|
tests += list_test_cases(Strxor_cTests)
|
||||||
return tests
|
return tests
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
suite = lambda: unittest.TestSuite(get_tests())
|
suite = lambda: unittest.TestSuite(get_tests())
|
||||||
unittest.main(defaultTest='suite')
|
unittest.main(defaultTest='suite')
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
|
import struct
|
||||||
from Crypto import Random
|
from Crypto import Random
|
||||||
from Crypto.Util.py3compat import *
|
from Crypto.Util.py3compat import *
|
||||||
|
|
||||||
|
|
@ -60,8 +61,8 @@ def getRandomInteger(N, randfunc=None):
|
||||||
S = randfunc(N>>3)
|
S = randfunc(N>>3)
|
||||||
odd_bits = N % 8
|
odd_bits = N % 8
|
||||||
if odd_bits != 0:
|
if odd_bits != 0:
|
||||||
char = ord(randfunc(1)) >> (8-odd_bits)
|
rand_bits = ord(randfunc(1)) >> (8-odd_bits)
|
||||||
S = bchr(char) + S
|
S = struct.pack('B', rand_bits) + S
|
||||||
value = bytes_to_long(S)
|
value = bytes_to_long(S)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
@ -380,7 +381,7 @@ def long_to_bytes(n, blocksize=0):
|
||||||
be of minimal length.
|
be of minimal length.
|
||||||
"""
|
"""
|
||||||
# after much testing, this algorithm was deemed to be the fastest
|
# after much testing, this algorithm was deemed to be the fastest
|
||||||
s = b('')
|
s = b''
|
||||||
n = int(n)
|
n = int(n)
|
||||||
pack = struct.pack
|
pack = struct.pack
|
||||||
while n > 0:
|
while n > 0:
|
||||||
|
|
@ -388,17 +389,17 @@ def long_to_bytes(n, blocksize=0):
|
||||||
n = n >> 32
|
n = n >> 32
|
||||||
# strip off leading zeros
|
# strip off leading zeros
|
||||||
for i in range(len(s)):
|
for i in range(len(s)):
|
||||||
if s[i] != b('\000')[0]:
|
if s[i] != b'\x00'[0]:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# only happens when n == 0
|
# only happens when n == 0
|
||||||
s = b('\000')
|
s = b'\x00'
|
||||||
i = 0
|
i = 0
|
||||||
s = s[i:]
|
s = s[i:]
|
||||||
# add back some pad bytes. this could be done more efficiently w.r.t. the
|
# add back some pad bytes. this could be done more efficiently w.r.t. the
|
||||||
# de-padding being done above, but sigh...
|
# de-padding being done above, but sigh...
|
||||||
if blocksize > 0 and len(s) % blocksize:
|
if blocksize > 0 and len(s) % blocksize:
|
||||||
s = (blocksize - len(s) % blocksize) * b('\000') + s
|
s = (blocksize - len(s) % blocksize) * b'\x00' + s
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def bytes_to_long(s):
|
def bytes_to_long(s):
|
||||||
|
|
@ -416,11 +417,17 @@ def bytes_to_long(s):
|
||||||
This is (essentially) the inverse of :func:`long_to_bytes`.
|
This is (essentially) the inverse of :func:`long_to_bytes`.
|
||||||
"""
|
"""
|
||||||
acc = 0
|
acc = 0
|
||||||
|
|
||||||
unpack = struct.unpack
|
unpack = struct.unpack
|
||||||
|
|
||||||
|
# Up to Python 2.7.3, struct.unpack can't work with bytearrays
|
||||||
|
if sys.version_info[0] < 3 and isinstance(s, bytearray):
|
||||||
|
s = bytes(s)
|
||||||
|
|
||||||
length = len(s)
|
length = len(s)
|
||||||
if length % 4:
|
if length % 4:
|
||||||
extra = (4 - length % 4)
|
extra = (4 - length % 4)
|
||||||
s = b('\000') * extra + s
|
s = b'\x00' * extra + s
|
||||||
length = length + extra
|
length = length + extra
|
||||||
for i in range(0, length, 4):
|
for i in range(0, length, 4):
|
||||||
acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
|
acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue