|
1 #-*- coding: ISO-8859-1 -*- |
|
2 # pysqlite2/test/userfunctions.py: tests for user-defined functions and |
|
3 # aggregates. |
|
4 # |
|
5 # Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> |
|
6 # |
|
7 # This file is part of pysqlite. |
|
8 # |
|
9 # This software is provided 'as-is', without any express or implied |
|
10 # warranty. In no event will the authors be held liable for any damages |
|
11 # arising from the use of this software. |
|
12 # |
|
13 # Permission is granted to anyone to use this software for any purpose, |
|
14 # including commercial applications, and to alter it and redistribute it |
|
15 # freely, subject to the following restrictions: |
|
16 # |
|
17 # 1. The origin of this software must not be misrepresented; you must not |
|
18 # claim that you wrote the original software. If you use this software |
|
19 # in a product, an acknowledgment in the product documentation would be |
|
20 # appreciated but is not required. |
|
21 # 2. Altered source versions must be plainly marked as such, and must not be |
|
22 # misrepresented as being the original software. |
|
23 # 3. This notice may not be removed or altered from any source distribution. |
|
24 |
|
25 import unittest |
|
26 import sqlite3 as sqlite |
|
27 |
|
28 def func_returntext(): |
|
29 return "foo" |
|
30 def func_returnunicode(): |
|
31 return u"bar" |
|
32 def func_returnint(): |
|
33 return 42 |
|
34 def func_returnfloat(): |
|
35 return 3.14 |
|
36 def func_returnnull(): |
|
37 return None |
|
38 def func_returnblob(): |
|
39 return buffer("blob") |
|
40 def func_raiseexception(): |
|
41 5/0 |
|
42 |
|
43 def func_isstring(v): |
|
44 return type(v) is unicode |
|
45 def func_isint(v): |
|
46 return type(v) is int |
|
47 def func_isfloat(v): |
|
48 return type(v) is float |
|
49 def func_isnone(v): |
|
50 return type(v) is type(None) |
|
51 def func_isblob(v): |
|
52 return type(v) is buffer |
|
53 |
|
54 class AggrNoStep: |
|
55 def __init__(self): |
|
56 pass |
|
57 |
|
58 def finalize(self): |
|
59 return 1 |
|
60 |
|
61 class AggrNoFinalize: |
|
62 def __init__(self): |
|
63 pass |
|
64 |
|
65 def step(self, x): |
|
66 pass |
|
67 |
|
68 class AggrExceptionInInit: |
|
69 def __init__(self): |
|
70 5/0 |
|
71 |
|
72 def step(self, x): |
|
73 pass |
|
74 |
|
75 def finalize(self): |
|
76 pass |
|
77 |
|
78 class AggrExceptionInStep: |
|
79 def __init__(self): |
|
80 pass |
|
81 |
|
82 def step(self, x): |
|
83 5/0 |
|
84 |
|
85 def finalize(self): |
|
86 return 42 |
|
87 |
|
88 class AggrExceptionInFinalize: |
|
89 def __init__(self): |
|
90 pass |
|
91 |
|
92 def step(self, x): |
|
93 pass |
|
94 |
|
95 def finalize(self): |
|
96 5/0 |
|
97 |
|
98 class AggrCheckType: |
|
99 def __init__(self): |
|
100 self.val = None |
|
101 |
|
102 def step(self, whichType, val): |
|
103 theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer} |
|
104 self.val = int(theType[whichType] is type(val)) |
|
105 |
|
106 def finalize(self): |
|
107 return self.val |
|
108 |
|
109 class AggrSum: |
|
110 def __init__(self): |
|
111 self.val = 0.0 |
|
112 |
|
113 def step(self, val): |
|
114 self.val += val |
|
115 |
|
116 def finalize(self): |
|
117 return self.val |
|
118 |
|
119 class FunctionTests(unittest.TestCase): |
|
120 def setUp(self): |
|
121 self.con = sqlite.connect(":memory:") |
|
122 |
|
123 self.con.create_function("returntext", 0, func_returntext) |
|
124 self.con.create_function("returnunicode", 0, func_returnunicode) |
|
125 self.con.create_function("returnint", 0, func_returnint) |
|
126 self.con.create_function("returnfloat", 0, func_returnfloat) |
|
127 self.con.create_function("returnnull", 0, func_returnnull) |
|
128 self.con.create_function("returnblob", 0, func_returnblob) |
|
129 self.con.create_function("raiseexception", 0, func_raiseexception) |
|
130 |
|
131 self.con.create_function("isstring", 1, func_isstring) |
|
132 self.con.create_function("isint", 1, func_isint) |
|
133 self.con.create_function("isfloat", 1, func_isfloat) |
|
134 self.con.create_function("isnone", 1, func_isnone) |
|
135 self.con.create_function("isblob", 1, func_isblob) |
|
136 |
|
137 def tearDown(self): |
|
138 self.con.close() |
|
139 |
|
140 def CheckFuncErrorOnCreate(self): |
|
141 try: |
|
142 self.con.create_function("bla", -100, lambda x: 2*x) |
|
143 self.fail("should have raised an OperationalError") |
|
144 except sqlite.OperationalError: |
|
145 pass |
|
146 |
|
147 def CheckFuncRefCount(self): |
|
148 def getfunc(): |
|
149 def f(): |
|
150 return 1 |
|
151 return f |
|
152 f = getfunc() |
|
153 globals()["foo"] = f |
|
154 # self.con.create_function("reftest", 0, getfunc()) |
|
155 self.con.create_function("reftest", 0, f) |
|
156 cur = self.con.cursor() |
|
157 cur.execute("select reftest()") |
|
158 |
|
159 def CheckFuncReturnText(self): |
|
160 cur = self.con.cursor() |
|
161 cur.execute("select returntext()") |
|
162 val = cur.fetchone()[0] |
|
163 self.failUnlessEqual(type(val), unicode) |
|
164 self.failUnlessEqual(val, "foo") |
|
165 |
|
166 def CheckFuncReturnUnicode(self): |
|
167 cur = self.con.cursor() |
|
168 cur.execute("select returnunicode()") |
|
169 val = cur.fetchone()[0] |
|
170 self.failUnlessEqual(type(val), unicode) |
|
171 self.failUnlessEqual(val, u"bar") |
|
172 |
|
173 def CheckFuncReturnInt(self): |
|
174 cur = self.con.cursor() |
|
175 cur.execute("select returnint()") |
|
176 val = cur.fetchone()[0] |
|
177 self.failUnlessEqual(type(val), int) |
|
178 self.failUnlessEqual(val, 42) |
|
179 |
|
180 def CheckFuncReturnFloat(self): |
|
181 cur = self.con.cursor() |
|
182 cur.execute("select returnfloat()") |
|
183 val = cur.fetchone()[0] |
|
184 self.failUnlessEqual(type(val), float) |
|
185 if val < 3.139 or val > 3.141: |
|
186 self.fail("wrong value") |
|
187 |
|
188 def CheckFuncReturnNull(self): |
|
189 cur = self.con.cursor() |
|
190 cur.execute("select returnnull()") |
|
191 val = cur.fetchone()[0] |
|
192 self.failUnlessEqual(type(val), type(None)) |
|
193 self.failUnlessEqual(val, None) |
|
194 |
|
195 def CheckFuncReturnBlob(self): |
|
196 cur = self.con.cursor() |
|
197 cur.execute("select returnblob()") |
|
198 val = cur.fetchone()[0] |
|
199 self.failUnlessEqual(type(val), buffer) |
|
200 self.failUnlessEqual(val, buffer("blob")) |
|
201 |
|
202 def CheckFuncException(self): |
|
203 cur = self.con.cursor() |
|
204 try: |
|
205 cur.execute("select raiseexception()") |
|
206 cur.fetchone() |
|
207 self.fail("should have raised OperationalError") |
|
208 except sqlite.OperationalError, e: |
|
209 self.failUnlessEqual(e.args[0], 'user-defined function raised exception') |
|
210 |
|
211 def CheckParamString(self): |
|
212 cur = self.con.cursor() |
|
213 cur.execute("select isstring(?)", ("foo",)) |
|
214 val = cur.fetchone()[0] |
|
215 self.failUnlessEqual(val, 1) |
|
216 |
|
217 def CheckParamInt(self): |
|
218 cur = self.con.cursor() |
|
219 cur.execute("select isint(?)", (42,)) |
|
220 val = cur.fetchone()[0] |
|
221 self.failUnlessEqual(val, 1) |
|
222 |
|
223 def CheckParamFloat(self): |
|
224 cur = self.con.cursor() |
|
225 cur.execute("select isfloat(?)", (3.14,)) |
|
226 val = cur.fetchone()[0] |
|
227 self.failUnlessEqual(val, 1) |
|
228 |
|
229 def CheckParamNone(self): |
|
230 cur = self.con.cursor() |
|
231 cur.execute("select isnone(?)", (None,)) |
|
232 val = cur.fetchone()[0] |
|
233 self.failUnlessEqual(val, 1) |
|
234 |
|
235 def CheckParamBlob(self): |
|
236 cur = self.con.cursor() |
|
237 cur.execute("select isblob(?)", (buffer("blob"),)) |
|
238 val = cur.fetchone()[0] |
|
239 self.failUnlessEqual(val, 1) |
|
240 |
|
241 class AggregateTests(unittest.TestCase): |
|
242 def setUp(self): |
|
243 self.con = sqlite.connect(":memory:") |
|
244 cur = self.con.cursor() |
|
245 cur.execute(""" |
|
246 create table test( |
|
247 t text, |
|
248 i integer, |
|
249 f float, |
|
250 n, |
|
251 b blob |
|
252 ) |
|
253 """) |
|
254 cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", |
|
255 ("foo", 5, 3.14, None, buffer("blob"),)) |
|
256 |
|
257 self.con.create_aggregate("nostep", 1, AggrNoStep) |
|
258 self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) |
|
259 self.con.create_aggregate("excInit", 1, AggrExceptionInInit) |
|
260 self.con.create_aggregate("excStep", 1, AggrExceptionInStep) |
|
261 self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize) |
|
262 self.con.create_aggregate("checkType", 2, AggrCheckType) |
|
263 self.con.create_aggregate("mysum", 1, AggrSum) |
|
264 |
|
265 def tearDown(self): |
|
266 #self.cur.close() |
|
267 #self.con.close() |
|
268 pass |
|
269 |
|
270 def CheckAggrErrorOnCreate(self): |
|
271 try: |
|
272 self.con.create_function("bla", -100, AggrSum) |
|
273 self.fail("should have raised an OperationalError") |
|
274 except sqlite.OperationalError: |
|
275 pass |
|
276 |
|
277 def CheckAggrNoStep(self): |
|
278 cur = self.con.cursor() |
|
279 try: |
|
280 cur.execute("select nostep(t) from test") |
|
281 self.fail("should have raised an AttributeError") |
|
282 except AttributeError, e: |
|
283 self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'") |
|
284 |
|
285 def CheckAggrNoFinalize(self): |
|
286 cur = self.con.cursor() |
|
287 try: |
|
288 cur.execute("select nofinalize(t) from test") |
|
289 val = cur.fetchone()[0] |
|
290 self.fail("should have raised an OperationalError") |
|
291 except sqlite.OperationalError, e: |
|
292 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") |
|
293 |
|
294 def CheckAggrExceptionInInit(self): |
|
295 cur = self.con.cursor() |
|
296 try: |
|
297 cur.execute("select excInit(t) from test") |
|
298 val = cur.fetchone()[0] |
|
299 self.fail("should have raised an OperationalError") |
|
300 except sqlite.OperationalError, e: |
|
301 self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error") |
|
302 |
|
303 def CheckAggrExceptionInStep(self): |
|
304 cur = self.con.cursor() |
|
305 try: |
|
306 cur.execute("select excStep(t) from test") |
|
307 val = cur.fetchone()[0] |
|
308 self.fail("should have raised an OperationalError") |
|
309 except sqlite.OperationalError, e: |
|
310 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error") |
|
311 |
|
312 def CheckAggrExceptionInFinalize(self): |
|
313 cur = self.con.cursor() |
|
314 try: |
|
315 cur.execute("select excFinalize(t) from test") |
|
316 val = cur.fetchone()[0] |
|
317 self.fail("should have raised an OperationalError") |
|
318 except sqlite.OperationalError, e: |
|
319 self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") |
|
320 |
|
321 def CheckAggrCheckParamStr(self): |
|
322 cur = self.con.cursor() |
|
323 cur.execute("select checkType('str', ?)", ("foo",)) |
|
324 val = cur.fetchone()[0] |
|
325 self.failUnlessEqual(val, 1) |
|
326 |
|
327 def CheckAggrCheckParamInt(self): |
|
328 cur = self.con.cursor() |
|
329 cur.execute("select checkType('int', ?)", (42,)) |
|
330 val = cur.fetchone()[0] |
|
331 self.failUnlessEqual(val, 1) |
|
332 |
|
333 def CheckAggrCheckParamFloat(self): |
|
334 cur = self.con.cursor() |
|
335 cur.execute("select checkType('float', ?)", (3.14,)) |
|
336 val = cur.fetchone()[0] |
|
337 self.failUnlessEqual(val, 1) |
|
338 |
|
339 def CheckAggrCheckParamNone(self): |
|
340 cur = self.con.cursor() |
|
341 cur.execute("select checkType('None', ?)", (None,)) |
|
342 val = cur.fetchone()[0] |
|
343 self.failUnlessEqual(val, 1) |
|
344 |
|
345 def CheckAggrCheckParamBlob(self): |
|
346 cur = self.con.cursor() |
|
347 cur.execute("select checkType('blob', ?)", (buffer("blob"),)) |
|
348 val = cur.fetchone()[0] |
|
349 self.failUnlessEqual(val, 1) |
|
350 |
|
351 def CheckAggrCheckAggrSum(self): |
|
352 cur = self.con.cursor() |
|
353 cur.execute("delete from test") |
|
354 cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) |
|
355 cur.execute("select mysum(i) from test") |
|
356 val = cur.fetchone()[0] |
|
357 self.failUnlessEqual(val, 60) |
|
358 |
|
359 def authorizer_cb(action, arg1, arg2, dbname, source): |
|
360 if action != sqlite.SQLITE_SELECT: |
|
361 return sqlite.SQLITE_DENY |
|
362 if arg2 == 'c2' or arg1 == 't2': |
|
363 return sqlite.SQLITE_DENY |
|
364 return sqlite.SQLITE_OK |
|
365 |
|
366 class AuthorizerTests(unittest.TestCase): |
|
367 def setUp(self): |
|
368 self.con = sqlite.connect(":memory:") |
|
369 self.con.executescript(""" |
|
370 create table t1 (c1, c2); |
|
371 create table t2 (c1, c2); |
|
372 insert into t1 (c1, c2) values (1, 2); |
|
373 insert into t2 (c1, c2) values (4, 5); |
|
374 """) |
|
375 |
|
376 # For our security test: |
|
377 self.con.execute("select c2 from t2") |
|
378 |
|
379 self.con.set_authorizer(authorizer_cb) |
|
380 |
|
381 def tearDown(self): |
|
382 pass |
|
383 |
|
384 def CheckTableAccess(self): |
|
385 try: |
|
386 self.con.execute("select * from t2") |
|
387 except sqlite.DatabaseError, e: |
|
388 if not e.args[0].endswith("prohibited"): |
|
389 self.fail("wrong exception text: %s" % e.args[0]) |
|
390 return |
|
391 self.fail("should have raised an exception due to missing privileges") |
|
392 |
|
393 def CheckColumnAccess(self): |
|
394 try: |
|
395 self.con.execute("select c2 from t1") |
|
396 except sqlite.DatabaseError, e: |
|
397 if not e.args[0].endswith("prohibited"): |
|
398 self.fail("wrong exception text: %s" % e.args[0]) |
|
399 return |
|
400 self.fail("should have raised an exception due to missing privileges") |
|
401 |
|
402 def suite(): |
|
403 function_suite = unittest.makeSuite(FunctionTests, "Check") |
|
404 aggregate_suite = unittest.makeSuite(AggregateTests, "Check") |
|
405 authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check") |
|
406 return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite)) |
|
407 |
|
408 def test(): |
|
409 runner = unittest.TextTestRunner() |
|
410 runner.run(suite()) |
|
411 |
|
412 if __name__ == "__main__": |
|
413 test() |