| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | #-*- coding: ISO-8859-1 -*- | 
					
						
							|  |  |  |  | # pysqlite2/test/userfunctions.py: tests for user-defined functions and | 
					
						
							|  |  |  |  | #                                  aggregates. | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # This file is part of pysqlite. | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # This software is provided 'as-is', without any express or implied | 
					
						
							|  |  |  |  | # warranty.  In no event will the authors be held liable for any damages | 
					
						
							|  |  |  |  | # arising from the use of this software. | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # Permission is granted to anyone to use this software for any purpose, | 
					
						
							|  |  |  |  | # including commercial applications, and to alter it and redistribute it | 
					
						
							|  |  |  |  | # freely, subject to the following restrictions: | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # 1. The origin of this software must not be misrepresented; you must not | 
					
						
							|  |  |  |  | #    claim that you wrote the original software. If you use this software | 
					
						
							|  |  |  |  | #    in a product, an acknowledgment in the product documentation would be | 
					
						
							|  |  |  |  | #    appreciated but is not required. | 
					
						
							|  |  |  |  | # 2. Altered source versions must be plainly marked as such, and must not be | 
					
						
							|  |  |  |  | #    misrepresented as being the original software. | 
					
						
							|  |  |  |  | # 3. This notice may not be removed or altered from any source distribution. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import unittest | 
					
						
							|  |  |  |  | import sqlite3 as sqlite | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def func_returntext(): | 
					
						
							|  |  |  |  |     return "foo" | 
					
						
							|  |  |  |  | def func_returnunicode(): | 
					
						
							|  |  |  |  |     return u"bar" | 
					
						
							|  |  |  |  | def func_returnint(): | 
					
						
							|  |  |  |  |     return 42 | 
					
						
							|  |  |  |  | def func_returnfloat(): | 
					
						
							|  |  |  |  |     return 3.14 | 
					
						
							|  |  |  |  | def func_returnnull(): | 
					
						
							|  |  |  |  |     return None | 
					
						
							|  |  |  |  | def func_returnblob(): | 
					
						
							|  |  |  |  |     return buffer("blob") | 
					
						
							|  |  |  |  | def func_raiseexception(): | 
					
						
							|  |  |  |  |     5/0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def func_isstring(v): | 
					
						
							|  |  |  |  |     return type(v) is unicode | 
					
						
							|  |  |  |  | def func_isint(v): | 
					
						
							|  |  |  |  |     return type(v) is int | 
					
						
							|  |  |  |  | def func_isfloat(v): | 
					
						
							|  |  |  |  |     return type(v) is float | 
					
						
							|  |  |  |  | def func_isnone(v): | 
					
						
							|  |  |  |  |     return type(v) is type(None) | 
					
						
							|  |  |  |  | def func_isblob(v): | 
					
						
							|  |  |  |  |     return type(v) is buffer | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrNoStep: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         return 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | class AggrNoFinalize: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, x): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrExceptionInInit: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         5/0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, x): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrExceptionInStep: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, x): | 
					
						
							|  |  |  |  |         5/0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         return 42 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrExceptionInFinalize: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, x): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         5/0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrCheckType: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         self.val = None | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, whichType, val): | 
					
						
							|  |  |  |  |         theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer} | 
					
						
							|  |  |  |  |         self.val = int(theType[whichType] is type(val)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         return self.val | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggrSum: | 
					
						
							|  |  |  |  |     def __init__(self): | 
					
						
							|  |  |  |  |         self.val = 0.0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def step(self, val): | 
					
						
							|  |  |  |  |         self.val += val | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def finalize(self): | 
					
						
							|  |  |  |  |         return self.val | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class FunctionTests(unittest.TestCase): | 
					
						
							|  |  |  |  |     def setUp(self): | 
					
						
							|  |  |  |  |         self.con = sqlite.connect(":memory:") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         self.con.create_function("returntext", 0, func_returntext) | 
					
						
							|  |  |  |  |         self.con.create_function("returnunicode", 0, func_returnunicode) | 
					
						
							|  |  |  |  |         self.con.create_function("returnint", 0, func_returnint) | 
					
						
							|  |  |  |  |         self.con.create_function("returnfloat", 0, func_returnfloat) | 
					
						
							|  |  |  |  |         self.con.create_function("returnnull", 0, func_returnnull) | 
					
						
							|  |  |  |  |         self.con.create_function("returnblob", 0, func_returnblob) | 
					
						
							|  |  |  |  |         self.con.create_function("raiseexception", 0, func_raiseexception) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         self.con.create_function("isstring", 1, func_isstring) | 
					
						
							|  |  |  |  |         self.con.create_function("isint", 1, func_isint) | 
					
						
							|  |  |  |  |         self.con.create_function("isfloat", 1, func_isfloat) | 
					
						
							|  |  |  |  |         self.con.create_function("isnone", 1, func_isnone) | 
					
						
							|  |  |  |  |         self.con.create_function("isblob", 1, func_isblob) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def tearDown(self): | 
					
						
							|  |  |  |  |         self.con.close() | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-23 15:24:26 +00:00
										 |  |  |  |     def CheckFuncErrorOnCreate(self): | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             self.con.create_function("bla", -100, lambda x: 2*x) | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError: | 
					
						
							|  |  |  |  |             pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  |     def CheckFuncRefCount(self): | 
					
						
							|  |  |  |  |         def getfunc(): | 
					
						
							|  |  |  |  |             def f(): | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |                 return 1 | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  |             return f | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         f = getfunc() | 
					
						
							|  |  |  |  |         globals()["foo"] = f | 
					
						
							|  |  |  |  |         # self.con.create_function("reftest", 0, getfunc()) | 
					
						
							|  |  |  |  |         self.con.create_function("reftest", 0, f) | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select reftest()") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnText(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returntext()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), unicode) | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, "foo") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnUnicode(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returnunicode()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), unicode) | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, u"bar") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnInt(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returnint()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), int) | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 42) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnFloat(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returnfloat()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), float) | 
					
						
							|  |  |  |  |         if val < 3.139 or val > 3.141: | 
					
						
							|  |  |  |  |             self.fail("wrong value") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnNull(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returnnull()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), type(None)) | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, None) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncReturnBlob(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select returnblob()") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(type(val), buffer) | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, buffer("blob")) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckFuncException(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select raiseexception()") | 
					
						
							|  |  |  |  |             cur.fetchone() | 
					
						
							|  |  |  |  |             self.fail("should have raised OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], 'user-defined function raised exception') | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckParamString(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select isstring(?)", ("foo",)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckParamInt(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select isint(?)", (42,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckParamFloat(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select isfloat(?)", (3.14,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckParamNone(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select isnone(?)", (None,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckParamBlob(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select isblob(?)", (buffer("blob"),)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AggregateTests(unittest.TestCase): | 
					
						
							|  |  |  |  |     def setUp(self): | 
					
						
							|  |  |  |  |         self.con = sqlite.connect(":memory:") | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("""
 | 
					
						
							|  |  |  |  |             create table test( | 
					
						
							|  |  |  |  |                 t text, | 
					
						
							|  |  |  |  |                 i integer, | 
					
						
							|  |  |  |  |                 f float, | 
					
						
							|  |  |  |  |                 n, | 
					
						
							|  |  |  |  |                 b blob | 
					
						
							|  |  |  |  |                 ) | 
					
						
							|  |  |  |  |             """)
 | 
					
						
							|  |  |  |  |         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", | 
					
						
							|  |  |  |  |             ("foo", 5, 3.14, None, buffer("blob"),)) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         self.con.create_aggregate("nostep", 1, AggrNoStep) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("excInit", 1, AggrExceptionInInit) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("excStep", 1, AggrExceptionInStep) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("checkType", 2, AggrCheckType) | 
					
						
							|  |  |  |  |         self.con.create_aggregate("mysum", 1, AggrSum) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def tearDown(self): | 
					
						
							|  |  |  |  |         #self.cur.close() | 
					
						
							|  |  |  |  |         #self.con.close() | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-23 15:24:26 +00:00
										 |  |  |  |     def CheckAggrErrorOnCreate(self): | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             self.con.create_function("bla", -100, AggrSum) | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError: | 
					
						
							|  |  |  |  |             pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  |     def CheckAggrNoStep(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select nostep(t) from test") | 
					
						
							|  |  |  |  |             self.fail("should have raised an AttributeError") | 
					
						
							|  |  |  |  |         except AttributeError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'") | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrNoFinalize(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select nofinalize(t) from test") | 
					
						
							|  |  |  |  |             val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrExceptionInInit(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select excInit(t) from test") | 
					
						
							|  |  |  |  |             val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error") | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrExceptionInStep(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select excStep(t) from test") | 
					
						
							|  |  |  |  |             val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error") | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrExceptionInFinalize(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |         try: | 
					
						
							|  |  |  |  |             cur.execute("select excFinalize(t) from test") | 
					
						
							|  |  |  |  |             val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |             self.fail("should have raised an OperationalError") | 
					
						
							|  |  |  |  |         except sqlite.OperationalError, e: | 
					
						
							|  |  |  |  |             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckParamStr(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select checkType('str', ?)", ("foo",)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckParamInt(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select checkType('int', ?)", (42,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckParamFloat(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select checkType('float', ?)", (3.14,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckParamNone(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select checkType('None', ?)", (None,)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckParamBlob(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("select checkType('blob', ?)", (buffer("blob"),)) | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 1) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckAggrCheckAggrSum(self): | 
					
						
							|  |  |  |  |         cur = self.con.cursor() | 
					
						
							|  |  |  |  |         cur.execute("delete from test") | 
					
						
							|  |  |  |  |         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) | 
					
						
							|  |  |  |  |         cur.execute("select mysum(i) from test") | 
					
						
							|  |  |  |  |         val = cur.fetchone()[0] | 
					
						
							|  |  |  |  |         self.failUnlessEqual(val, 60) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  | def authorizer_cb(action, arg1, arg2, dbname, source): | 
					
						
							|  |  |  |  |     if action != sqlite.SQLITE_SELECT: | 
					
						
							|  |  |  |  |         return sqlite.SQLITE_DENY | 
					
						
							|  |  |  |  |     if arg2 == 'c2' or arg1 == 't2': | 
					
						
							|  |  |  |  |         return sqlite.SQLITE_DENY | 
					
						
							|  |  |  |  |     return sqlite.SQLITE_OK | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class AuthorizerTests(unittest.TestCase): | 
					
						
							|  |  |  |  |     def setUp(self): | 
					
						
							|  |  |  |  |         sqlite.enable_callback_tracebacks(1) | 
					
						
							|  |  |  |  |         self.con = sqlite.connect(":memory:") | 
					
						
							|  |  |  |  |         self.con.executescript("""
 | 
					
						
							|  |  |  |  |             create table t1 (c1, c2); | 
					
						
							|  |  |  |  |             create table t2 (c1, c2); | 
					
						
							|  |  |  |  |             insert into t1 (c1, c2) values (1, 2); | 
					
						
							|  |  |  |  |             insert into t2 (c1, c2) values (4, 5); | 
					
						
							|  |  |  |  |             """)
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # For our security test: | 
					
						
							|  |  |  |  |         self.con.execute("select c2 from t2") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         self.con.set_authorizer(authorizer_cb) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def tearDown(self): | 
					
						
							|  |  |  |  |         pass | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckTableAccess(self): | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             self.con.execute("select * from t2") | 
					
						
							|  |  |  |  |         except sqlite.DatabaseError, e: | 
					
						
							|  |  |  |  |             if not e.args[0].endswith("prohibited"): | 
					
						
							|  |  |  |  |                 self.fail("wrong exception text: %s" % e.args[0]) | 
					
						
							|  |  |  |  |             return | 
					
						
							|  |  |  |  |         self.fail("should have raised an exception due to missing privileges") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def CheckColumnAccess(self): | 
					
						
							|  |  |  |  |         try: | 
					
						
							|  |  |  |  |             self.con.execute("select c2 from t1") | 
					
						
							|  |  |  |  |         except sqlite.DatabaseError, e: | 
					
						
							|  |  |  |  |             if not e.args[0].endswith("prohibited"): | 
					
						
							|  |  |  |  |                 self.fail("wrong exception text: %s" % e.args[0]) | 
					
						
							|  |  |  |  |             return | 
					
						
							|  |  |  |  |         self.fail("should have raised an exception due to missing privileges") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | def suite(): | 
					
						
							|  |  |  |  |     function_suite = unittest.makeSuite(FunctionTests, "Check") | 
					
						
							|  |  |  |  |     aggregate_suite = unittest.makeSuite(AggregateTests, "Check") | 
					
						
							| 
									
										
										
										
											2006-06-14 04:15:27 +00:00
										 |  |  |  |     authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check") | 
					
						
							| 
									
										
										
										
											2006-06-13 22:24:47 +00:00
										 |  |  |  |     return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite)) | 
					
						
							| 
									
										
										
										
											2006-04-01 00:57:31 +00:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | def test(): | 
					
						
							|  |  |  |  |     runner = unittest.TextTestRunner() | 
					
						
							|  |  |  |  |     runner.run(suite()) | 
					
						
							|  |  |  |  | 
 | 
					
						
# Run the whole suite when this file is executed as a script.
if __name__ == "__main__":
    test()