mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	Close #15573: use value-based memoryview comparisons (patch by Stefan Krah)
This commit is contained in:
		
							parent
							
								
									5c0b1ca55e
								
							
						
					
					
						commit
						06e1ab0a6b
					
				
					 5 changed files with 778 additions and 132 deletions
				
			
		|  | @ -246,7 +246,7 @@ Create a new memoryview object which references the given object."); | |||
|     (view->suboffsets && view->suboffsets[dest->ndim-1] >= 0) | ||||
| 
 | ||||
| Py_LOCAL_INLINE(int) | ||||
| last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src) | ||||
| last_dim_is_contiguous(const Py_buffer *dest, const Py_buffer *src) | ||||
| { | ||||
|     assert(dest->ndim > 0 && src->ndim > 0); | ||||
|     return (!HAVE_SUBOFFSETS_IN_LAST_DIM(dest) && | ||||
|  | @ -255,37 +255,63 @@ last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src) | |||
|             src->strides[src->ndim-1] == src->itemsize); | ||||
| } | ||||
| 
 | ||||
| /* Check that the logical structure of the destination and source buffers
 | ||||
|    is identical. */ | ||||
| static int | ||||
| cmp_structure(Py_buffer *dest, Py_buffer *src) | ||||
| /* This is not a general function for determining format equivalence.
 | ||||
|    It is used in copy_single() and copy_buffer() to weed out non-matching | ||||
|    formats. Skipping the '@' character is specifically used in slice | ||||
|    assignments, where the lvalue is already known to have a single character | ||||
|    format. This is a performance hack that could be rewritten (if properly | ||||
|    benchmarked). */ | ||||
| Py_LOCAL_INLINE(int) | ||||
| equiv_format(const Py_buffer *dest, const Py_buffer *src) | ||||
| { | ||||
|     const char *dfmt, *sfmt; | ||||
|     int i; | ||||
| 
 | ||||
|     assert(dest->format && src->format); | ||||
|     dfmt = dest->format[0] == '@' ? dest->format+1 : dest->format; | ||||
|     sfmt = src->format[0] == '@' ? src->format+1 : src->format; | ||||
| 
 | ||||
|     if (strcmp(dfmt, sfmt) != 0 || | ||||
|         dest->itemsize != src->itemsize || | ||||
|         dest->ndim != src->ndim) { | ||||
|         goto value_error; | ||||
|         dest->itemsize != src->itemsize) { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| /* Two shapes are equivalent if they are either equal or identical up
 | ||||
|    to a zero element at the same position. For example, in NumPy arrays | ||||
|    the shapes [1, 0, 5] and [1, 0, 7] are equivalent. */ | ||||
| Py_LOCAL_INLINE(int) | ||||
| equiv_shape(const Py_buffer *dest, const Py_buffer *src) | ||||
| { | ||||
|     int i; | ||||
| 
 | ||||
|     if (dest->ndim != src->ndim) | ||||
|         return 0; | ||||
| 
 | ||||
|     for (i = 0; i < dest->ndim; i++) { | ||||
|         if (dest->shape[i] != src->shape[i]) | ||||
|             goto value_error; | ||||
|             return 0; | ||||
|         if (dest->shape[i] == 0) | ||||
|             break; | ||||
|     } | ||||
| 
 | ||||
|     return 0; | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| value_error: | ||||
|     PyErr_SetString(PyExc_ValueError, | ||||
|         "ndarray assignment: lvalue and rvalue have different structures"); | ||||
|     return -1; | ||||
| /* Check that the logical structure of the destination and source buffers
 | ||||
|    is identical. */ | ||||
| static int | ||||
| equiv_structure(const Py_buffer *dest, const Py_buffer *src) | ||||
| { | ||||
|     if (!equiv_format(dest, src) || | ||||
|         !equiv_shape(dest, src)) { | ||||
|         PyErr_SetString(PyExc_ValueError, | ||||
|             "ndarray assignment: lvalue and rvalue have different structures"); | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     return 1; | ||||
| } | ||||
| 
 | ||||
| /* Base case for recursive multi-dimensional copying. Contiguous arrays are
 | ||||
|  | @ -358,7 +384,7 @@ copy_single(Py_buffer *dest, Py_buffer *src) | |||
| 
 | ||||
|     assert(dest->ndim == 1); | ||||
| 
 | ||||
|     if (cmp_structure(dest, src) < 0) | ||||
|     if (!equiv_structure(dest, src)) | ||||
|         return -1; | ||||
| 
 | ||||
|     if (!last_dim_is_contiguous(dest, src)) { | ||||
|  | @ -390,7 +416,7 @@ copy_buffer(Py_buffer *dest, Py_buffer *src) | |||
| 
 | ||||
|     assert(dest->ndim > 0); | ||||
| 
 | ||||
|     if (cmp_structure(dest, src) < 0) | ||||
|     if (!equiv_structure(dest, src)) | ||||
|         return -1; | ||||
| 
 | ||||
|     if (!last_dim_is_contiguous(dest, src)) { | ||||
|  | @ -1827,6 +1853,131 @@ pack_single(char *ptr, PyObject *item, const char *fmt) | |||
| } | ||||
| 
 | ||||
| 
 | ||||
| /****************************************************************************/ | ||||
| /*                       unpack using the struct module                     */ | ||||
| /****************************************************************************/ | ||||
| 
 | ||||
| /* For reasonable performance it is necessary to cache all objects required
 | ||||
|    for unpacking. An unpacker can handle the format passed to unpack_from(). | ||||
|    Invariant: All pointer fields of the struct should either be NULL or valid | ||||
|    pointers. */ | ||||
| struct unpacker { | ||||
|     PyObject *unpack_from; /* Struct.unpack_from(format) */ | ||||
|     PyObject *mview;       /* cached memoryview */ | ||||
|     char *item;            /* buffer for mview */ | ||||
|     Py_ssize_t itemsize;   /* len(item) */ | ||||
| }; | ||||
| 
 | ||||
| static struct unpacker * | ||||
| unpacker_new(void) | ||||
| { | ||||
|     struct unpacker *x = PyMem_Malloc(sizeof *x); | ||||
| 
 | ||||
|     if (x == NULL) { | ||||
|         PyErr_NoMemory(); | ||||
|         return NULL; | ||||
|     } | ||||
| 
 | ||||
|     x->unpack_from = NULL; | ||||
|     x->mview = NULL; | ||||
|     x->item = NULL; | ||||
|     x->itemsize = 0; | ||||
| 
 | ||||
|     return x; | ||||
| } | ||||
| 
 | ||||
| static void | ||||
| unpacker_free(struct unpacker *x) | ||||
| { | ||||
|     if (x) { | ||||
|         Py_XDECREF(x->unpack_from); | ||||
|         Py_XDECREF(x->mview); | ||||
|         PyMem_Free(x->item); | ||||
|         PyMem_Free(x); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /* Return a new unpacker for the given format. */ | ||||
| static struct unpacker * | ||||
| struct_get_unpacker(const char *fmt, Py_ssize_t itemsize) | ||||
| { | ||||
|     PyObject *structmodule;     /* XXX cache these two */ | ||||
|     PyObject *Struct = NULL;    /* XXX in globals?     */ | ||||
|     PyObject *structobj = NULL; | ||||
|     PyObject *format = NULL; | ||||
|     struct unpacker *x = NULL; | ||||
| 
 | ||||
|     structmodule = PyImport_ImportModule("struct"); | ||||
|     if (structmodule == NULL) | ||||
|         return NULL; | ||||
| 
 | ||||
|     Struct = PyObject_GetAttrString(structmodule, "Struct"); | ||||
|     Py_DECREF(structmodule); | ||||
|     if (Struct == NULL) | ||||
|         return NULL; | ||||
| 
 | ||||
|     x = unpacker_new(); | ||||
|     if (x == NULL) | ||||
|         goto error; | ||||
| 
 | ||||
|     format = PyBytes_FromString(fmt); | ||||
|     if (format == NULL) | ||||
|         goto error; | ||||
| 
 | ||||
|     structobj = PyObject_CallFunctionObjArgs(Struct, format, NULL); | ||||
|     if (structobj == NULL) | ||||
|         goto error; | ||||
| 
 | ||||
|     x->unpack_from = PyObject_GetAttrString(structobj, "unpack_from"); | ||||
|     if (x->unpack_from == NULL) | ||||
|         goto error; | ||||
| 
 | ||||
|     x->item = PyMem_Malloc(itemsize); | ||||
|     if (x->item == NULL) { | ||||
|         PyErr_NoMemory(); | ||||
|         goto error; | ||||
|     } | ||||
|     x->itemsize = itemsize; | ||||
| 
 | ||||
|     x->mview = PyMemoryView_FromMemory(x->item, itemsize, PyBUF_WRITE); | ||||
|     if (x->mview == NULL) | ||||
|         goto error; | ||||
| 
 | ||||
| 
 | ||||
| out: | ||||
|     Py_XDECREF(Struct); | ||||
|     Py_XDECREF(format); | ||||
|     Py_XDECREF(structobj); | ||||
|     return x; | ||||
| 
 | ||||
| error: | ||||
|     unpacker_free(x); | ||||
|     x = NULL; | ||||
|     goto out; | ||||
| } | ||||
| 
 | ||||
| /* unpack a single item */ | ||||
| static PyObject * | ||||
| struct_unpack_single(const char *ptr, struct unpacker *x) | ||||
| { | ||||
|     PyObject *v; | ||||
| 
 | ||||
|     memcpy(x->item, ptr, x->itemsize); | ||||
|     v = PyObject_CallFunctionObjArgs(x->unpack_from, x->mview, NULL); | ||||
|     if (v == NULL) | ||||
|         return NULL; | ||||
| 
 | ||||
|     if (PyTuple_GET_SIZE(v) == 1) { | ||||
|         PyObject *tmp = PyTuple_GET_ITEM(v, 0); | ||||
|         Py_INCREF(tmp); | ||||
|         Py_DECREF(v); | ||||
|         return tmp; | ||||
|     } | ||||
| 
 | ||||
|     return v; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| /****************************************************************************/ | ||||
| /*                              Representations                             */ | ||||
| /****************************************************************************/ | ||||
|  | @ -2261,6 +2412,58 @@ static PySequenceMethods memory_as_sequence = { | |||
| /*                             Comparisons                                */ | ||||
| /**************************************************************************/ | ||||
| 
 | ||||
| #define MV_COMPARE_EX -1       /* exception */ | ||||
| #define MV_COMPARE_NOT_IMPL -2 /* not implemented */ | ||||
| 
 | ||||
| /* Translate a StructError to "not equal". Preserve other exceptions. */ | ||||
| static int | ||||
| fix_struct_error_int(void) | ||||
| { | ||||
|     assert(PyErr_Occurred()); | ||||
|     /* XXX Cannot get at StructError directly? */ | ||||
|     if (PyErr_ExceptionMatches(PyExc_ImportError) || | ||||
|         PyErr_ExceptionMatches(PyExc_MemoryError)) { | ||||
|         return MV_COMPARE_EX; | ||||
|     } | ||||
|     /* StructError: invalid or unknown format -> not equal */ | ||||
|     PyErr_Clear(); | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
| /* Unpack and compare single items of p and q using the struct module. */ | ||||
| static int | ||||
| struct_unpack_cmp(const char *p, const char *q, | ||||
|                   struct unpacker *unpack_p, struct unpacker *unpack_q) | ||||
| { | ||||
|     PyObject *v, *w; | ||||
|     int ret; | ||||
| 
 | ||||
|     /* At this point any exception from the struct module should not be
 | ||||
|        StructError, since both formats have been accepted already. */ | ||||
|     v = struct_unpack_single(p, unpack_p); | ||||
|     if (v == NULL) | ||||
|         return MV_COMPARE_EX; | ||||
| 
 | ||||
|     w = struct_unpack_single(q, unpack_q); | ||||
|     if (w == NULL) { | ||||
|         Py_DECREF(v); | ||||
|         return MV_COMPARE_EX; | ||||
|     } | ||||
| 
 | ||||
|     /* MV_COMPARE_EX == -1: exceptions are preserved */ | ||||
|     ret = PyObject_RichCompareBool(v, w, Py_EQ); | ||||
|     Py_DECREF(v); | ||||
|     Py_DECREF(w); | ||||
| 
 | ||||
|     return ret; | ||||
| } | ||||
| 
 | ||||
| /* Unpack and compare single items of p and q. If both p and q have the same
 | ||||
|    single element native format, the comparison uses a fast path (gcc creates | ||||
|    a jump table and converts memcpy into simple assignments on x86/x64). | ||||
| 
 | ||||
|    Otherwise, the comparison is delegated to the struct module, which is | ||||
|    30-60x slower. */ | ||||
| #define CMP_SINGLE(p, q, type) \ | ||||
|     do {                                 \ | ||||
|         type x;                          \ | ||||
|  | @ -2271,11 +2474,12 @@ static PySequenceMethods memory_as_sequence = { | |||
|     } while (0) | ||||
| 
 | ||||
| Py_LOCAL_INLINE(int) | ||||
| unpack_cmp(const char *p, const char *q, const char *fmt) | ||||
| unpack_cmp(const char *p, const char *q, char fmt, | ||||
|            struct unpacker *unpack_p, struct unpacker *unpack_q) | ||||
| { | ||||
|     int equal; | ||||
| 
 | ||||
|     switch (fmt[0]) { | ||||
|     switch (fmt) { | ||||
| 
 | ||||
|     /* signed integers and fast path for 'B' */ | ||||
|     case 'B': return *((unsigned char *)p) == *((unsigned char *)q); | ||||
|  | @ -2317,9 +2521,17 @@ unpack_cmp(const char *p, const char *q, const char *fmt) | |||
|     /* pointer */ | ||||
|     case 'P': CMP_SINGLE(p, q, void *); return equal; | ||||
| 
 | ||||
|     /* Py_NotImplemented */ | ||||
|     default: return -1; | ||||
|     /* use the struct module */ | ||||
|     case '_': | ||||
|         assert(unpack_p); | ||||
|         assert(unpack_q); | ||||
|         return struct_unpack_cmp(p, q, unpack_p, unpack_q); | ||||
|     } | ||||
| 
 | ||||
|     /* NOT REACHED */ | ||||
|     PyErr_SetString(PyExc_RuntimeError, | ||||
|         "memoryview: internal error in richcompare"); | ||||
|     return MV_COMPARE_EX; | ||||
| } | ||||
| 
 | ||||
| /* Base case for recursive array comparisons. Assumption: ndim == 1. */ | ||||
|  | @ -2327,7 +2539,7 @@ static int | |||
| cmp_base(const char *p, const char *q, const Py_ssize_t *shape, | ||||
|          const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets, | ||||
|          const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets, | ||||
|          const char *fmt) | ||||
|          char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     int equal; | ||||
|  | @ -2335,7 +2547,7 @@ cmp_base(const char *p, const char *q, const Py_ssize_t *shape, | |||
|     for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) { | ||||
|         const char *xp = ADJUST_PTR(p, psuboffsets); | ||||
|         const char *xq = ADJUST_PTR(q, qsuboffsets); | ||||
|         equal = unpack_cmp(xp, xq, fmt); | ||||
|         equal = unpack_cmp(xp, xq, fmt, unpack_p, unpack_q); | ||||
|         if (equal <= 0) | ||||
|             return equal; | ||||
|     } | ||||
|  | @ -2350,7 +2562,7 @@ cmp_rec(const char *p, const char *q, | |||
|         Py_ssize_t ndim, const Py_ssize_t *shape, | ||||
|         const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets, | ||||
|         const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets, | ||||
|         const char *fmt) | ||||
|         char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q) | ||||
| { | ||||
|     Py_ssize_t i; | ||||
|     int equal; | ||||
|  | @ -2364,7 +2576,7 @@ cmp_rec(const char *p, const char *q, | |||
|         return cmp_base(p, q, shape, | ||||
|                         pstrides, psuboffsets, | ||||
|                         qstrides, qsuboffsets, | ||||
|                         fmt); | ||||
|                         fmt, unpack_p, unpack_q); | ||||
|     } | ||||
| 
 | ||||
|     for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) { | ||||
|  | @ -2373,7 +2585,7 @@ cmp_rec(const char *p, const char *q, | |||
|         equal = cmp_rec(xp, xq, ndim-1, shape+1, | ||||
|                         pstrides+1, psuboffsets ? psuboffsets+1 : NULL, | ||||
|                         qstrides+1, qsuboffsets ? qsuboffsets+1 : NULL, | ||||
|                         fmt); | ||||
|                         fmt, unpack_p, unpack_q); | ||||
|         if (equal <= 0) | ||||
|             return equal; | ||||
|     } | ||||
|  | @ -2385,9 +2597,12 @@ static PyObject * | |||
| memory_richcompare(PyObject *v, PyObject *w, int op) | ||||
| { | ||||
|     PyObject *res; | ||||
|     Py_buffer wbuf, *vv, *ww = NULL; | ||||
|     const char *vfmt, *wfmt; | ||||
|     int equal = -1; /* Py_NotImplemented */ | ||||
|     Py_buffer wbuf, *vv; | ||||
|     Py_buffer *ww = NULL; | ||||
|     struct unpacker *unpack_v = NULL; | ||||
|     struct unpacker *unpack_w = NULL; | ||||
|     char vfmt, wfmt; | ||||
|     int equal = MV_COMPARE_NOT_IMPL; | ||||
| 
 | ||||
|     if (op != Py_EQ && op != Py_NE) | ||||
|         goto result; /* Py_NotImplemented */ | ||||
|  | @ -2414,38 +2629,59 @@ memory_richcompare(PyObject *v, PyObject *w, int op) | |||
|         ww = &wbuf; | ||||
|     } | ||||
| 
 | ||||
|     vfmt = adjust_fmt(vv); | ||||
|     wfmt = adjust_fmt(ww); | ||||
|     if (vfmt == NULL || wfmt == NULL) { | ||||
|         PyErr_Clear(); | ||||
|         goto result; /* Py_NotImplemented */ | ||||
|     } | ||||
| 
 | ||||
|     if (cmp_structure(vv, ww) < 0) { | ||||
|     if (!equiv_shape(vv, ww)) { | ||||
|         PyErr_Clear(); | ||||
|         equal = 0; | ||||
|         goto result; | ||||
|     } | ||||
| 
 | ||||
|     /* Use fast unpacking for identical primitive C type formats. */ | ||||
|     if (get_native_fmtchar(&vfmt, vv->format) < 0) | ||||
|         vfmt = '_'; | ||||
|     if (get_native_fmtchar(&wfmt, ww->format) < 0) | ||||
|         wfmt = '_'; | ||||
|     if (vfmt == '_' || wfmt == '_' || vfmt != wfmt) { | ||||
|         /* Use struct module unpacking. NOTE: Even for equal format strings,
 | ||||
|            memcmp() cannot be used for item comparison since it would give | ||||
|            incorrect results in the case of NaNs or uninitialized padding | ||||
|            bytes. */ | ||||
|         vfmt = '_'; | ||||
|         unpack_v = struct_get_unpacker(vv->format, vv->itemsize); | ||||
|         if (unpack_v == NULL) { | ||||
|             equal = fix_struct_error_int(); | ||||
|             goto result; | ||||
|         } | ||||
|         unpack_w = struct_get_unpacker(ww->format, ww->itemsize); | ||||
|         if (unpack_w == NULL) { | ||||
|             equal = fix_struct_error_int(); | ||||
|             goto result; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (vv->ndim == 0) { | ||||
|         equal = unpack_cmp(vv->buf, ww->buf, vfmt); | ||||
|         equal = unpack_cmp(vv->buf, ww->buf, | ||||
|                            vfmt, unpack_v, unpack_w); | ||||
|     } | ||||
|     else if (vv->ndim == 1) { | ||||
|         equal = cmp_base(vv->buf, ww->buf, vv->shape, | ||||
|                          vv->strides, vv->suboffsets, | ||||
|                          ww->strides, ww->suboffsets, | ||||
|                          vfmt); | ||||
|                          vfmt, unpack_v, unpack_w); | ||||
|     } | ||||
|     else { | ||||
|         equal = cmp_rec(vv->buf, ww->buf, vv->ndim, vv->shape, | ||||
|                         vv->strides, vv->suboffsets, | ||||
|                         ww->strides, ww->suboffsets, | ||||
|                         vfmt); | ||||
|                         vfmt, unpack_v, unpack_w); | ||||
|     } | ||||
| 
 | ||||
| result: | ||||
|     if (equal < 0) | ||||
|         res = Py_NotImplemented;  | ||||
|     if (equal < 0) { | ||||
|         if (equal == MV_COMPARE_NOT_IMPL) | ||||
|             res = Py_NotImplemented; | ||||
|         else /* exception */ | ||||
|             res = NULL; | ||||
|     } | ||||
|     else if ((equal && op == Py_EQ) || (!equal && op == Py_NE)) | ||||
|         res = Py_True; | ||||
|     else | ||||
|  | @ -2453,7 +2689,11 @@ memory_richcompare(PyObject *v, PyObject *w, int op) | |||
| 
 | ||||
|     if (ww == &wbuf) | ||||
|         PyBuffer_Release(ww); | ||||
|     Py_INCREF(res); | ||||
| 
 | ||||
|     unpacker_free(unpack_v); | ||||
|     unpacker_free(unpack_w); | ||||
| 
 | ||||
|     Py_XINCREF(res); | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Nick Coghlan
						Nick Coghlan