gh-140911: Ensure that UserString.index() and UserString.rindex() accept UserString as argument (GH-140945)

This commit is contained in:
Krishna Chaitanya 2025-11-25 18:55:46 +05:30 committed by GitHub
parent d07d3a3c57
commit e6174ee981
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 33 deletions

View file

@ -1542,6 +1542,8 @@ def format_map(self, mapping):
return self.data.format_map(mapping)
def index(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.index(sub, start, end)
def isalpha(self):
@ -1610,6 +1612,8 @@ def rfind(self, sub, start=0, end=_sys.maxsize):
return self.data.rfind(sub, start, end)
def rindex(self, sub, start=0, end=_sys.maxsize):
if isinstance(sub, UserString):
sub = sub.data
return self.data.rindex(sub, start, end)
def rjust(self, width, *args):

View file

@ -90,6 +90,18 @@ def checkcall(self, obj, methodname, *args):
args = self.fixtype(args)
getattr(obj, methodname)(*args)
def _get_teststrings(self, charset, digits):
base = len(charset)
teststrings = set()
for i in range(base ** digits):
entry = []
for j in range(digits):
i, m = divmod(i, base)
entry.append(charset[m])
teststrings.add(''.join(entry))
teststrings = [self.fixtype(ts) for ts in teststrings]
return teststrings
def test_count(self):
self.checkequal(3, 'aaa', 'count', 'a')
self.checkequal(0, 'aaa', 'count', 'b')
@ -130,17 +142,7 @@ def test_count(self):
# For a variety of combinations,
# verify that str.count() matches an equivalent function
# replacing all occurrences and then differencing the string lengths
charset = ['', 'a', 'b']
digits = 7
base = len(charset)
teststrings = set()
for i in range(base ** digits):
entry = []
for j in range(digits):
i, m = divmod(i, base)
entry.append(charset[m])
teststrings.add(''.join(entry))
teststrings = [self.fixtype(ts) for ts in teststrings]
teststrings = self._get_teststrings(['', 'a', 'b'], 7)
for i in teststrings:
n = len(i)
for j in teststrings:
@ -197,17 +199,7 @@ def test_find(self):
# For a variety of combinations,
# verify that str.find() matches __contains__
# and that the found substring is really at that location
charset = ['', 'a', 'b', 'c']
digits = 5
base = len(charset)
teststrings = set()
for i in range(base ** digits):
entry = []
for j in range(digits):
i, m = divmod(i, base)
entry.append(charset[m])
teststrings.add(''.join(entry))
teststrings = [self.fixtype(ts) for ts in teststrings]
teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5)
for i in teststrings:
for j in teststrings:
loc = i.find(j)
@ -244,17 +236,7 @@ def test_rfind(self):
# For a variety of combinations,
# verify that str.rfind() matches __contains__
# and that the found substring is really at that location
charset = ['', 'a', 'b', 'c']
digits = 5
base = len(charset)
teststrings = set()
for i in range(base ** digits):
entry = []
for j in range(digits):
i, m = divmod(i, base)
entry.append(charset[m])
teststrings.add(''.join(entry))
teststrings = [self.fixtype(ts) for ts in teststrings]
teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5)
for i in teststrings:
for j in teststrings:
loc = i.rfind(j)
@ -295,6 +277,19 @@ def test_index(self):
else:
self.checkraises(TypeError, 'hello', 'index', 42)
# For a variety of combinations,
# verify that str.index() matches __contains__
# and that the found substring is really at that location
teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5)
for i in teststrings:
for j in teststrings:
if j in i:
loc = i.index(j)
self.assertGreaterEqual(loc, 0)
self.assertEqual(i[loc:loc+len(j)], j)
else:
self.assertRaises(ValueError, i.index, j)
def test_rindex(self):
self.checkequal(12, 'abcdefghiabc', 'rindex', '')
self.checkequal(3, 'abcdefghiabc', 'rindex', 'def')
@ -321,6 +316,19 @@ def test_rindex(self):
else:
self.checkraises(TypeError, 'hello', 'rindex', 42)
# For a variety of combinations,
# verify that str.rindex() matches __contains__
# and that the found substring is really at that location
teststrings = self._get_teststrings(['', 'a', 'b', 'c'], 5)
for i in teststrings:
for j in teststrings:
if j in i:
loc = i.rindex(j)
self.assertGreaterEqual(loc, 0)
self.assertEqual(i[loc:loc+len(j)], j)
else:
self.assertRaises(ValueError, i.rindex, j)
def test_find_periodic_pattern(self):
"""Cover the special path for periodic patterns."""
def reference_find(p, s):

View file

@ -0,0 +1,3 @@
:mod:`collections`: Ensure that the methods ``UserString.rindex()`` and
``UserString.index()`` accept :class:`collections.UserString` instances as the
sub argument.