mirror of
https://github.com/Legrandin/pycryptodome.git
synced 2025-12-08 05:19:46 +00:00
Various small fixes and clean ups
This commit is contained in:
parent
c1834322b6
commit
53be0708d7
5 changed files with 63 additions and 34 deletions
|
|
@ -23,7 +23,7 @@
|
|||
from Crypto.Signature.pss import MGF1
|
||||
import Crypto.Hash.SHA1
|
||||
|
||||
from Crypto.Util.py3compat import *
|
||||
from Crypto.Util.py3compat import bord, _copy_bytes
|
||||
import Crypto.Util.number
|
||||
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
||||
from Crypto.Util.strxor import strxor
|
||||
|
|
@ -70,7 +70,7 @@ class PKCS1OAEP_Cipher:
|
|||
else:
|
||||
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
|
||||
|
||||
self._label = bstr(label)
|
||||
self._label = _copy_bytes(None, None, label)
|
||||
self._randfunc = randfunc
|
||||
|
||||
def can_encrypt(self):
|
||||
|
|
@ -102,7 +102,6 @@ class PKCS1OAEP_Cipher:
|
|||
:raises ValueError:
|
||||
if the message is too long.
|
||||
"""
|
||||
# TODO: Verify the key is RSA
|
||||
|
||||
# See 7.1.1 in RFC3447
|
||||
modBits = Crypto.Util.number.size(self._key.n)
|
||||
|
|
@ -117,9 +116,9 @@ class PKCS1OAEP_Cipher:
|
|||
# Step 2a
|
||||
lHash = self._hashObj.new(self._label).digest()
|
||||
# Step 2b
|
||||
ps = bchr(0x00)*ps_len
|
||||
ps = b'\x00' * ps_len
|
||||
# Step 2c
|
||||
db = lHash + ps + bchr(0x01) + message
|
||||
db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
|
||||
# Step 2d
|
||||
ros = self._randfunc(hLen)
|
||||
# Step 2e
|
||||
|
|
@ -131,7 +130,7 @@ class PKCS1OAEP_Cipher:
|
|||
# Step 2h
|
||||
maskedSeed = strxor(ros, seedMask)
|
||||
# Step 2i
|
||||
em = bchr(0x00) + maskedSeed + maskedDB
|
||||
em = b'\x00' + maskedSeed + maskedDB
|
||||
# Step 3a (OS2IP)
|
||||
em_int = bytes_to_long(em)
|
||||
# Step 3b (RSAEP)
|
||||
|
|
@ -167,7 +166,7 @@ class PKCS1OAEP_Cipher:
|
|||
if len(ciphertext) != k or k<hLen+2:
|
||||
raise ValueError("Ciphertext with incorrect length.")
|
||||
# Step 2a (O2SIP)
|
||||
ct_int = bytes_to_long(bstr(ciphertext))
|
||||
ct_int = bytes_to_long(ciphertext)
|
||||
# Step 2b (RSADP)
|
||||
m_int = self._key._decrypt(ct_int)
|
||||
# Complete step 2c (I2OSP)
|
||||
|
|
@ -190,7 +189,7 @@ class PKCS1OAEP_Cipher:
|
|||
db = strxor(maskedDB, dbMask)
|
||||
# Step 3g
|
||||
valid = 1
|
||||
one = db[hLen:].find(bchr(0x01))
|
||||
one = db[hLen:].find(b'\x01')
|
||||
lHash1 = db[:hLen]
|
||||
if lHash1!=lHash:
|
||||
valid = 0
|
||||
|
|
@ -203,7 +202,7 @@ class PKCS1OAEP_Cipher:
|
|||
# Step 4
|
||||
return db[hLen+one+1:]
|
||||
|
||||
def new(key, hashAlgo=None, mgfunc=None, label=b(''), randfunc=None):
|
||||
def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
|
||||
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
|
||||
|
||||
:param key:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
__all__ = [ 'new', 'PKCS115_Cipher' ]
|
||||
|
||||
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
|
||||
from Crypto.Util.py3compat import *
|
||||
from Crypto.Util.py3compat import bord, _copy_bytes
|
||||
import Crypto.Util.number
|
||||
from Crypto import Random
|
||||
|
||||
|
|
@ -88,10 +88,10 @@ class PKCS115_Cipher:
|
|||
if bord(new_byte[0]) == 0x00:
|
||||
continue
|
||||
ps.append(new_byte)
|
||||
ps = b("").join(ps)
|
||||
ps = b"".join(ps)
|
||||
assert(len(ps) == k - mLen - 3)
|
||||
# Step 2b
|
||||
em = b('\x00\x02') + ps + bchr(0x00) + bstr(message)
|
||||
em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
|
||||
# Step 3a (OS2IP)
|
||||
em_int = bytes_to_long(em)
|
||||
# Step 3b (RSAEP)
|
||||
|
|
@ -164,14 +164,14 @@ class PKCS115_Cipher:
|
|||
if len(ciphertext) != k:
|
||||
raise ValueError("Ciphertext with incorrect length.")
|
||||
# Step 2a (O2SIP)
|
||||
ct_int = bytes_to_long(bstr(ciphertext))
|
||||
ct_int = bytes_to_long(ciphertext)
|
||||
# Step 2b (RSADP)
|
||||
m_int = self._key._decrypt(ct_int)
|
||||
# Complete step 2c (I2OSP)
|
||||
em = long_to_bytes(m_int, k)
|
||||
# Step 3
|
||||
sep = em.find(bchr(0x00),2)
|
||||
if not em.startswith(b('\x00\x02')) or sep<10:
|
||||
sep = em.find(b'\x00', 2)
|
||||
if not em.startswith(b'\x00\x02') or sep < 10:
|
||||
return sentinel
|
||||
# Step 4
|
||||
return em[sep + 1:]
|
||||
|
|
|
|||
|
|
@ -160,14 +160,14 @@ HKukWBcq9f/UOmS0oEhai/6g+Uf7VHJdWaeO5LzuvwU=
|
|||
self.assertEqual(pt,pt2)
|
||||
|
||||
def testByteArray(self):
|
||||
pt = b("XER")
|
||||
pt = b"XER"
|
||||
cipher = PKCS.new(self.key1024)
|
||||
ct = cipher.encrypt(bytearray(pt))
|
||||
pt2 = cipher.decrypt(bytearray(ct), "---")
|
||||
self.assertEqual(pt, pt2)
|
||||
|
||||
def testMemoryview(self):
|
||||
pt = b("XER")
|
||||
pt = b"XER"
|
||||
cipher = PKCS.new(self.key1024)
|
||||
ct = cipher.encrypt(memoryview(bytearray(pt)))
|
||||
pt2 = cipher.decrypt(memoryview(bytearray(ct)), "---")
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@
|
|||
import unittest
|
||||
from binascii import unhexlify, hexlify
|
||||
|
||||
from Crypto.Util.py3compat import _memoryview
|
||||
from Crypto.SelfTest.st_common import list_test_cases
|
||||
from Crypto.Util.strxor import strxor, strxor_c
|
||||
|
||||
|
|
@ -61,16 +62,27 @@ class StrxorTests(unittest.TestCase):
|
|||
term2 = unhexlify(b"ff339a83e5cd4cdf564990")
|
||||
self.assertRaises(ValueError, strxor, term1, term2)
|
||||
|
||||
def test_bytearray_memoryview(self):
|
||||
def test_bytearray(self):
|
||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||
term1_ba = bytearray(term1)
|
||||
term1_mv = memoryview(term1)
|
||||
term2 = unhexlify(b"383d4ba020573314395b")
|
||||
result = unhexlify(b"c70ed123c59a7fcb6f12")
|
||||
|
||||
self.assertEqual(strxor(term1_ba, term2), result)
|
||||
|
||||
def test_memoryview(self):
|
||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||
term1_mv = memoryview(term1)
|
||||
term2 = unhexlify(b"383d4ba020573314395b")
|
||||
result = unhexlify(b"c70ed123c59a7fcb6f12")
|
||||
|
||||
self.assertEqual(strxor(term1_mv, term2), result)
|
||||
|
||||
import types
|
||||
if _memoryview is types.NoneType:
|
||||
del test_memoryview
|
||||
|
||||
|
||||
class Strxor_cTests(unittest.TestCase):
|
||||
|
||||
def test1(self):
|
||||
|
|
@ -90,21 +102,32 @@ class Strxor_cTests(unittest.TestCase):
|
|||
self.assertRaises(ValueError, strxor_c, term1, -1)
|
||||
self.assertRaises(ValueError, strxor_c, term1, 256)
|
||||
|
||||
def test_bytearray_memoryview(self):
|
||||
def test_bytearray(self):
|
||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||
term1_ba = bytearray(term1)
|
||||
term1_mv = memoryview(term1)
|
||||
result = unhexlify(b"be72dbc2a48c0d9e1708")
|
||||
|
||||
self.assertEqual(strxor_c(term1_ba, 65), result)
|
||||
|
||||
def test_memoryview(self):
|
||||
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
|
||||
term1_mv = memoryview(term1)
|
||||
result = unhexlify(b"be72dbc2a48c0d9e1708")
|
||||
|
||||
self.assertEqual(strxor_c(term1_mv, 65), result)
|
||||
|
||||
import types
|
||||
if _memoryview is types.NoneType:
|
||||
del test_memoryview
|
||||
|
||||
|
||||
def get_tests(config={}):
|
||||
tests = []
|
||||
tests += list_test_cases(StrxorTests)
|
||||
tests += list_test_cases(Strxor_cTests)
|
||||
return tests
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
suite = lambda: unittest.TestSuite(get_tests())
|
||||
unittest.main(defaultTest='suite')
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
import math
|
||||
import sys
|
||||
import struct
|
||||
from Crypto import Random
|
||||
from Crypto.Util.py3compat import *
|
||||
|
||||
|
|
@ -60,8 +61,8 @@ def getRandomInteger(N, randfunc=None):
|
|||
S = randfunc(N>>3)
|
||||
odd_bits = N % 8
|
||||
if odd_bits != 0:
|
||||
char = ord(randfunc(1)) >> (8-odd_bits)
|
||||
S = bchr(char) + S
|
||||
rand_bits = ord(randfunc(1)) >> (8-odd_bits)
|
||||
S = struct.pack('B', rand_bits) + S
|
||||
value = bytes_to_long(S)
|
||||
return value
|
||||
|
||||
|
|
@ -380,7 +381,7 @@ def long_to_bytes(n, blocksize=0):
|
|||
be of minimal length.
|
||||
"""
|
||||
# after much testing, this algorithm was deemed to be the fastest
|
||||
s = b('')
|
||||
s = b''
|
||||
n = int(n)
|
||||
pack = struct.pack
|
||||
while n > 0:
|
||||
|
|
@ -388,17 +389,17 @@ def long_to_bytes(n, blocksize=0):
|
|||
n = n >> 32
|
||||
# strip off leading zeros
|
||||
for i in range(len(s)):
|
||||
if s[i] != b('\000')[0]:
|
||||
if s[i] != b'\x00'[0]:
|
||||
break
|
||||
else:
|
||||
# only happens when n == 0
|
||||
s = b('\000')
|
||||
s = b'\x00'
|
||||
i = 0
|
||||
s = s[i:]
|
||||
# add back some pad bytes. this could be done more efficiently w.r.t. the
|
||||
# de-padding being done above, but sigh...
|
||||
if blocksize > 0 and len(s) % blocksize:
|
||||
s = (blocksize - len(s) % blocksize) * b('\000') + s
|
||||
s = (blocksize - len(s) % blocksize) * b'\x00' + s
|
||||
return s
|
||||
|
||||
def bytes_to_long(s):
|
||||
|
|
@ -416,11 +417,17 @@ def bytes_to_long(s):
|
|||
This is (essentially) the inverse of :func:`long_to_bytes`.
|
||||
"""
|
||||
acc = 0
|
||||
|
||||
unpack = struct.unpack
|
||||
|
||||
# Up to Python 2.7.3, struct.unpack can't work with bytearrays
|
||||
if sys.version_info[0] < 3 and isinstance(s, bytearray):
|
||||
s = bytes(s)
|
||||
|
||||
length = len(s)
|
||||
if length % 4:
|
||||
extra = (4 - length % 4)
|
||||
s = b('\000') * extra + s
|
||||
s = b'\x00' * extra + s
|
||||
length = length + extra
|
||||
for i in range(0, length, 4):
|
||||
acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue