Various small fixes and clean ups

This commit is contained in:
Helder Eijs 2018-04-03 11:18:31 +02:00
parent c1834322b6
commit 53be0708d7
5 changed files with 63 additions and 34 deletions

View file

@ -23,7 +23,7 @@
from Crypto.Signature.pss import MGF1
import Crypto.Hash.SHA1
from Crypto.Util.py3compat import *
from Crypto.Util.py3compat import bord, _copy_bytes
import Crypto.Util.number
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
from Crypto.Util.strxor import strxor
@ -70,7 +70,7 @@ class PKCS1OAEP_Cipher:
else:
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
self._label = bstr(label)
self._label = _copy_bytes(None, None, label)
self._randfunc = randfunc
def can_encrypt(self):
@ -102,24 +102,23 @@ class PKCS1OAEP_Cipher:
:raises ValueError:
if the message is too long.
"""
# TODO: Verify the key is RSA
# See 7.1.1 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits,8) # Convert from bits to bytes
k = ceil_div(modBits, 8) # Convert from bits to bytes
hLen = self._hashObj.digest_size
mLen = len(message)
# Step 1b
ps_len = k-mLen-2*hLen-2
if ps_len<0:
ps_len = k - mLen - 2 * hLen - 2
if ps_len < 0:
raise ValueError("Plaintext is too long.")
# Step 2a
lHash = self._hashObj.new(self._label).digest()
# Step 2b
ps = bchr(0x00)*ps_len
ps = b'\x00' * ps_len
# Step 2c
db = lHash + ps + bchr(0x01) + message
db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
# Step 2d
ros = self._randfunc(hLen)
# Step 2e
@ -131,7 +130,7 @@ class PKCS1OAEP_Cipher:
# Step 2h
maskedSeed = strxor(ros, seedMask)
# Step 2i
em = bchr(0x00) + maskedSeed + maskedDB
em = b'\x00' + maskedSeed + maskedDB
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
@ -167,7 +166,7 @@ class PKCS1OAEP_Cipher:
if len(ciphertext) != k or k<hLen+2:
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(bstr(ciphertext))
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
m_int = self._key._decrypt(ct_int)
# Complete step 2c (I2OSP)
@ -190,20 +189,20 @@ class PKCS1OAEP_Cipher:
db = strxor(maskedDB, dbMask)
# Step 3g
valid = 1
one = db[hLen:].find(bchr(0x01))
one = db[hLen:].find(b'\x01')
lHash1 = db[:hLen]
if lHash1!=lHash:
valid = 0
if one<0:
valid = 0
if bord(y)!=0:
if bord(y) != 0:
valid = 0
if not valid:
raise ValueError("Incorrect decryption.")
# Step 4
return db[hLen+one+1:]
def new(key, hashAlgo=None, mgfunc=None, label=b(''), randfunc=None):
def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
:param key:

View file

@ -23,7 +23,7 @@
__all__ = [ 'new', 'PKCS115_Cipher' ]
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
from Crypto.Util.py3compat import *
from Crypto.Util.py3compat import bord, _copy_bytes
import Crypto.Util.number
from Crypto import Random
@ -79,7 +79,7 @@ class PKCS115_Cipher:
mLen = len(message)
# Step 1
if mLen > k-11:
if mLen > k - 11:
raise ValueError("Plaintext is too long.")
# Step 2a
ps = []
@ -88,10 +88,10 @@ class PKCS115_Cipher:
if bord(new_byte[0]) == 0x00:
continue
ps.append(new_byte)
ps = b("").join(ps)
ps = b"".join(ps)
assert(len(ps) == k - mLen - 3)
# Step 2b
em = b('\x00\x02') + ps + bchr(0x00) + bstr(message)
em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
@ -164,17 +164,17 @@ class PKCS115_Cipher:
if len(ciphertext) != k:
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(bstr(ciphertext))
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
m_int = self._key._decrypt(ct_int)
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3
sep = em.find(bchr(0x00),2)
if not em.startswith(b('\x00\x02')) or sep<10:
sep = em.find(b'\x00', 2)
if not em.startswith(b'\x00\x02') or sep < 10:
return sentinel
# Step 4
return em[sep+1:]
return em[sep + 1:]
def new(key, randfunc=None):

View file

@ -160,14 +160,14 @@ HKukWBcq9f/UOmS0oEhai/6g+Uf7VHJdWaeO5LzuvwU=
self.assertEqual(pt,pt2)
def testByteArray(self):
pt = b("XER")
pt = b"XER"
cipher = PKCS.new(self.key1024)
ct = cipher.encrypt(bytearray(pt))
pt2 = cipher.decrypt(bytearray(ct), "---")
self.assertEqual(pt, pt2)
def testMemoryview(self):
pt = b("XER")
pt = b"XER"
cipher = PKCS.new(self.key1024)
ct = cipher.encrypt(memoryview(bytearray(pt)))
pt2 = cipher.decrypt(memoryview(bytearray(ct)), "---")

View file

@ -34,6 +34,7 @@
import unittest
from binascii import unhexlify, hexlify
from Crypto.Util.py3compat import _memoryview
from Crypto.SelfTest.st_common import list_test_cases
from Crypto.Util.strxor import strxor, strxor_c
@ -61,16 +62,27 @@ class StrxorTests(unittest.TestCase):
term2 = unhexlify(b"ff339a83e5cd4cdf564990")
self.assertRaises(ValueError, strxor, term1, term2)
def test_bytearray_memoryview(self):
def test_bytearray(self):
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
term1_ba = bytearray(term1)
term1_mv = memoryview(term1)
term2 = unhexlify(b"383d4ba020573314395b")
result = unhexlify(b"c70ed123c59a7fcb6f12")
self.assertEqual(strxor(term1_ba, term2), result)
def test_memoryview(self):
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
term1_mv = memoryview(term1)
term2 = unhexlify(b"383d4ba020573314395b")
result = unhexlify(b"c70ed123c59a7fcb6f12")
self.assertEqual(strxor(term1_mv, term2), result)
import types
if _memoryview is types.NoneType:
del test_memoryview
class Strxor_cTests(unittest.TestCase):
def test1(self):
@ -90,21 +102,32 @@ class Strxor_cTests(unittest.TestCase):
self.assertRaises(ValueError, strxor_c, term1, -1)
self.assertRaises(ValueError, strxor_c, term1, 256)
def test_bytearray_memoryview(self):
def test_bytearray(self):
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
term1_ba = bytearray(term1)
term1_mv = memoryview(term1)
result = unhexlify(b"be72dbc2a48c0d9e1708")
self.assertEqual(strxor_c(term1_ba, 65), result)
def test_memoryview(self):
term1 = unhexlify(b"ff339a83e5cd4cdf5649")
term1_mv = memoryview(term1)
result = unhexlify(b"be72dbc2a48c0d9e1708")
self.assertEqual(strxor_c(term1_mv, 65), result)
import types
if _memoryview is types.NoneType:
del test_memoryview
def get_tests(config={}):
tests = []
tests += list_test_cases(StrxorTests)
tests += list_test_cases(Strxor_cTests)
return tests
if __name__ == '__main__':
suite = lambda: unittest.TestSuite(get_tests())
unittest.main(defaultTest='suite')

View file

@ -26,6 +26,7 @@
import math
import sys
import struct
from Crypto import Random
from Crypto.Util.py3compat import *
@ -60,8 +61,8 @@ def getRandomInteger(N, randfunc=None):
S = randfunc(N>>3)
odd_bits = N % 8
if odd_bits != 0:
char = ord(randfunc(1)) >> (8-odd_bits)
S = bchr(char) + S
rand_bits = ord(randfunc(1)) >> (8-odd_bits)
S = struct.pack('B', rand_bits) + S
value = bytes_to_long(S)
return value
@ -380,7 +381,7 @@ def long_to_bytes(n, blocksize=0):
be of minimal length.
"""
# after much testing, this algorithm was deemed to be the fastest
s = b('')
s = b''
n = int(n)
pack = struct.pack
while n > 0:
@ -388,17 +389,17 @@ def long_to_bytes(n, blocksize=0):
n = n >> 32
# strip off leading zeros
for i in range(len(s)):
if s[i] != b('\000')[0]:
if s[i] != b'\x00'[0]:
break
else:
# only happens when n == 0
s = b('\000')
s = b'\x00'
i = 0
s = s[i:]
# add back some pad bytes. this could be done more efficiently w.r.t. the
# de-padding being done above, but sigh...
if blocksize > 0 and len(s) % blocksize:
s = (blocksize - len(s) % blocksize) * b('\000') + s
s = (blocksize - len(s) % blocksize) * b'\x00' + s
return s
def bytes_to_long(s):
@ -416,11 +417,17 @@ def bytes_to_long(s):
This is (essentially) the inverse of :func:`long_to_bytes`.
"""
acc = 0
unpack = struct.unpack
# Up to Python 2.7.3, struct.unpack can't work with bytearrays
if sys.version_info[0] < 3 and isinstance(s, bytearray):
s = bytes(s)
length = len(s)
if length % 4:
extra = (4 - length % 4)
s = b('\000') * extra + s
s = b'\x00' * extra + s
length = length + extra
for i in range(0, length, 4):
acc = (acc << 32) + unpack('>I', s[i:i+4])[0]