python-2.5.2/win32/Lib/sqlite3/test/userfunctions.py
changeset 0 ae805ac0140d
equal deleted inserted replaced
-1:000000000000 0:ae805ac0140d
       
     1 #-*- coding: ISO-8859-1 -*-
       
     2 # pysqlite2/test/userfunctions.py: tests for user-defined functions and
       
     3 #                                  aggregates.
       
     4 #
       
     5 # Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
       
     6 #
       
     7 # This file is part of pysqlite.
       
     8 #
       
     9 # This software is provided 'as-is', without any express or implied
       
    10 # warranty.  In no event will the authors be held liable for any damages
       
    11 # arising from the use of this software.
       
    12 #
       
    13 # Permission is granted to anyone to use this software for any purpose,
       
    14 # including commercial applications, and to alter it and redistribute it
       
    15 # freely, subject to the following restrictions:
       
    16 #
       
    17 # 1. The origin of this software must not be misrepresented; you must not
       
    18 #    claim that you wrote the original software. If you use this software
       
    19 #    in a product, an acknowledgment in the product documentation would be
       
    20 #    appreciated but is not required.
       
    21 # 2. Altered source versions must be plainly marked as such, and must not be
       
    22 #    misrepresented as being the original software.
       
    23 # 3. This notice may not be removed or altered from any source distribution.
       
    24 
       
    25 import unittest
       
    26 import sqlite3 as sqlite
       
    27 
       
    28 def func_returntext():
       
    29     return "foo"
       
    30 def func_returnunicode():
       
    31     return u"bar"
       
    32 def func_returnint():
       
    33     return 42
       
    34 def func_returnfloat():
       
    35     return 3.14
       
    36 def func_returnnull():
       
    37     return None
       
    38 def func_returnblob():
       
    39     return buffer("blob")
       
    40 def func_raiseexception():
       
    41     5/0
       
    42 
       
    43 def func_isstring(v):
       
    44     return type(v) is unicode
       
    45 def func_isint(v):
       
    46     return type(v) is int
       
    47 def func_isfloat(v):
       
    48     return type(v) is float
       
    49 def func_isnone(v):
       
    50     return type(v) is type(None)
       
    51 def func_isblob(v):
       
    52     return type(v) is buffer
       
    53 
       
    54 class AggrNoStep:
       
    55     def __init__(self):
       
    56         pass
       
    57 
       
    58     def finalize(self):
       
    59         return 1
       
    60 
       
    61 class AggrNoFinalize:
       
    62     def __init__(self):
       
    63         pass
       
    64 
       
    65     def step(self, x):
       
    66         pass
       
    67 
       
    68 class AggrExceptionInInit:
       
    69     def __init__(self):
       
    70         5/0
       
    71 
       
    72     def step(self, x):
       
    73         pass
       
    74 
       
    75     def finalize(self):
       
    76         pass
       
    77 
       
    78 class AggrExceptionInStep:
       
    79     def __init__(self):
       
    80         pass
       
    81 
       
    82     def step(self, x):
       
    83         5/0
       
    84 
       
    85     def finalize(self):
       
    86         return 42
       
    87 
       
    88 class AggrExceptionInFinalize:
       
    89     def __init__(self):
       
    90         pass
       
    91 
       
    92     def step(self, x):
       
    93         pass
       
    94 
       
    95     def finalize(self):
       
    96         5/0
       
    97 
       
    98 class AggrCheckType:
       
    99     def __init__(self):
       
   100         self.val = None
       
   101 
       
   102     def step(self, whichType, val):
       
   103         theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
       
   104         self.val = int(theType[whichType] is type(val))
       
   105 
       
   106     def finalize(self):
       
   107         return self.val
       
   108 
       
   109 class AggrSum:
       
   110     def __init__(self):
       
   111         self.val = 0.0
       
   112 
       
   113     def step(self, val):
       
   114         self.val += val
       
   115 
       
   116     def finalize(self):
       
   117         return self.val
       
   118 
       
   119 class FunctionTests(unittest.TestCase):
       
   120     def setUp(self):
       
   121         self.con = sqlite.connect(":memory:")
       
   122 
       
   123         self.con.create_function("returntext", 0, func_returntext)
       
   124         self.con.create_function("returnunicode", 0, func_returnunicode)
       
   125         self.con.create_function("returnint", 0, func_returnint)
       
   126         self.con.create_function("returnfloat", 0, func_returnfloat)
       
   127         self.con.create_function("returnnull", 0, func_returnnull)
       
   128         self.con.create_function("returnblob", 0, func_returnblob)
       
   129         self.con.create_function("raiseexception", 0, func_raiseexception)
       
   130 
       
   131         self.con.create_function("isstring", 1, func_isstring)
       
   132         self.con.create_function("isint", 1, func_isint)
       
   133         self.con.create_function("isfloat", 1, func_isfloat)
       
   134         self.con.create_function("isnone", 1, func_isnone)
       
   135         self.con.create_function("isblob", 1, func_isblob)
       
   136 
       
   137     def tearDown(self):
       
   138         self.con.close()
       
   139 
       
   140     def CheckFuncErrorOnCreate(self):
       
   141         try:
       
   142             self.con.create_function("bla", -100, lambda x: 2*x)
       
   143             self.fail("should have raised an OperationalError")
       
   144         except sqlite.OperationalError:
       
   145             pass
       
   146 
       
   147     def CheckFuncRefCount(self):
       
   148         def getfunc():
       
   149             def f():
       
   150                 return 1
       
   151             return f
       
   152         f = getfunc()
       
   153         globals()["foo"] = f
       
   154         # self.con.create_function("reftest", 0, getfunc())
       
   155         self.con.create_function("reftest", 0, f)
       
   156         cur = self.con.cursor()
       
   157         cur.execute("select reftest()")
       
   158 
       
   159     def CheckFuncReturnText(self):
       
   160         cur = self.con.cursor()
       
   161         cur.execute("select returntext()")
       
   162         val = cur.fetchone()[0]
       
   163         self.failUnlessEqual(type(val), unicode)
       
   164         self.failUnlessEqual(val, "foo")
       
   165 
       
   166     def CheckFuncReturnUnicode(self):
       
   167         cur = self.con.cursor()
       
   168         cur.execute("select returnunicode()")
       
   169         val = cur.fetchone()[0]
       
   170         self.failUnlessEqual(type(val), unicode)
       
   171         self.failUnlessEqual(val, u"bar")
       
   172 
       
   173     def CheckFuncReturnInt(self):
       
   174         cur = self.con.cursor()
       
   175         cur.execute("select returnint()")
       
   176         val = cur.fetchone()[0]
       
   177         self.failUnlessEqual(type(val), int)
       
   178         self.failUnlessEqual(val, 42)
       
   179 
       
   180     def CheckFuncReturnFloat(self):
       
   181         cur = self.con.cursor()
       
   182         cur.execute("select returnfloat()")
       
   183         val = cur.fetchone()[0]
       
   184         self.failUnlessEqual(type(val), float)
       
   185         if val < 3.139 or val > 3.141:
       
   186             self.fail("wrong value")
       
   187 
       
   188     def CheckFuncReturnNull(self):
       
   189         cur = self.con.cursor()
       
   190         cur.execute("select returnnull()")
       
   191         val = cur.fetchone()[0]
       
   192         self.failUnlessEqual(type(val), type(None))
       
   193         self.failUnlessEqual(val, None)
       
   194 
       
   195     def CheckFuncReturnBlob(self):
       
   196         cur = self.con.cursor()
       
   197         cur.execute("select returnblob()")
       
   198         val = cur.fetchone()[0]
       
   199         self.failUnlessEqual(type(val), buffer)
       
   200         self.failUnlessEqual(val, buffer("blob"))
       
   201 
       
   202     def CheckFuncException(self):
       
   203         cur = self.con.cursor()
       
   204         try:
       
   205             cur.execute("select raiseexception()")
       
   206             cur.fetchone()
       
   207             self.fail("should have raised OperationalError")
       
   208         except sqlite.OperationalError, e:
       
   209             self.failUnlessEqual(e.args[0], 'user-defined function raised exception')
       
   210 
       
   211     def CheckParamString(self):
       
   212         cur = self.con.cursor()
       
   213         cur.execute("select isstring(?)", ("foo",))
       
   214         val = cur.fetchone()[0]
       
   215         self.failUnlessEqual(val, 1)
       
   216 
       
   217     def CheckParamInt(self):
       
   218         cur = self.con.cursor()
       
   219         cur.execute("select isint(?)", (42,))
       
   220         val = cur.fetchone()[0]
       
   221         self.failUnlessEqual(val, 1)
       
   222 
       
   223     def CheckParamFloat(self):
       
   224         cur = self.con.cursor()
       
   225         cur.execute("select isfloat(?)", (3.14,))
       
   226         val = cur.fetchone()[0]
       
   227         self.failUnlessEqual(val, 1)
       
   228 
       
   229     def CheckParamNone(self):
       
   230         cur = self.con.cursor()
       
   231         cur.execute("select isnone(?)", (None,))
       
   232         val = cur.fetchone()[0]
       
   233         self.failUnlessEqual(val, 1)
       
   234 
       
   235     def CheckParamBlob(self):
       
   236         cur = self.con.cursor()
       
   237         cur.execute("select isblob(?)", (buffer("blob"),))
       
   238         val = cur.fetchone()[0]
       
   239         self.failUnlessEqual(val, 1)
       
   240 
       
   241 class AggregateTests(unittest.TestCase):
       
   242     def setUp(self):
       
   243         self.con = sqlite.connect(":memory:")
       
   244         cur = self.con.cursor()
       
   245         cur.execute("""
       
   246             create table test(
       
   247                 t text,
       
   248                 i integer,
       
   249                 f float,
       
   250                 n,
       
   251                 b blob
       
   252                 )
       
   253             """)
       
   254         cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
       
   255             ("foo", 5, 3.14, None, buffer("blob"),))
       
   256 
       
   257         self.con.create_aggregate("nostep", 1, AggrNoStep)
       
   258         self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
       
   259         self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
       
   260         self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
       
   261         self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
       
   262         self.con.create_aggregate("checkType", 2, AggrCheckType)
       
   263         self.con.create_aggregate("mysum", 1, AggrSum)
       
   264 
       
   265     def tearDown(self):
       
   266         #self.cur.close()
       
   267         #self.con.close()
       
   268         pass
       
   269 
       
   270     def CheckAggrErrorOnCreate(self):
       
   271         try:
       
   272             self.con.create_function("bla", -100, AggrSum)
       
   273             self.fail("should have raised an OperationalError")
       
   274         except sqlite.OperationalError:
       
   275             pass
       
   276 
       
   277     def CheckAggrNoStep(self):
       
   278         cur = self.con.cursor()
       
   279         try:
       
   280             cur.execute("select nostep(t) from test")
       
   281             self.fail("should have raised an AttributeError")
       
   282         except AttributeError, e:
       
   283             self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")
       
   284 
       
   285     def CheckAggrNoFinalize(self):
       
   286         cur = self.con.cursor()
       
   287         try:
       
   288             cur.execute("select nofinalize(t) from test")
       
   289             val = cur.fetchone()[0]
       
   290             self.fail("should have raised an OperationalError")
       
   291         except sqlite.OperationalError, e:
       
   292             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
       
   293 
       
   294     def CheckAggrExceptionInInit(self):
       
   295         cur = self.con.cursor()
       
   296         try:
       
   297             cur.execute("select excInit(t) from test")
       
   298             val = cur.fetchone()[0]
       
   299             self.fail("should have raised an OperationalError")
       
   300         except sqlite.OperationalError, e:
       
   301             self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")
       
   302 
       
   303     def CheckAggrExceptionInStep(self):
       
   304         cur = self.con.cursor()
       
   305         try:
       
   306             cur.execute("select excStep(t) from test")
       
   307             val = cur.fetchone()[0]
       
   308             self.fail("should have raised an OperationalError")
       
   309         except sqlite.OperationalError, e:
       
   310             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error")
       
   311 
       
   312     def CheckAggrExceptionInFinalize(self):
       
   313         cur = self.con.cursor()
       
   314         try:
       
   315             cur.execute("select excFinalize(t) from test")
       
   316             val = cur.fetchone()[0]
       
   317             self.fail("should have raised an OperationalError")
       
   318         except sqlite.OperationalError, e:
       
   319             self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")
       
   320 
       
   321     def CheckAggrCheckParamStr(self):
       
   322         cur = self.con.cursor()
       
   323         cur.execute("select checkType('str', ?)", ("foo",))
       
   324         val = cur.fetchone()[0]
       
   325         self.failUnlessEqual(val, 1)
       
   326 
       
   327     def CheckAggrCheckParamInt(self):
       
   328         cur = self.con.cursor()
       
   329         cur.execute("select checkType('int', ?)", (42,))
       
   330         val = cur.fetchone()[0]
       
   331         self.failUnlessEqual(val, 1)
       
   332 
       
   333     def CheckAggrCheckParamFloat(self):
       
   334         cur = self.con.cursor()
       
   335         cur.execute("select checkType('float', ?)", (3.14,))
       
   336         val = cur.fetchone()[0]
       
   337         self.failUnlessEqual(val, 1)
       
   338 
       
   339     def CheckAggrCheckParamNone(self):
       
   340         cur = self.con.cursor()
       
   341         cur.execute("select checkType('None', ?)", (None,))
       
   342         val = cur.fetchone()[0]
       
   343         self.failUnlessEqual(val, 1)
       
   344 
       
   345     def CheckAggrCheckParamBlob(self):
       
   346         cur = self.con.cursor()
       
   347         cur.execute("select checkType('blob', ?)", (buffer("blob"),))
       
   348         val = cur.fetchone()[0]
       
   349         self.failUnlessEqual(val, 1)
       
   350 
       
   351     def CheckAggrCheckAggrSum(self):
       
   352         cur = self.con.cursor()
       
   353         cur.execute("delete from test")
       
   354         cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
       
   355         cur.execute("select mysum(i) from test")
       
   356         val = cur.fetchone()[0]
       
   357         self.failUnlessEqual(val, 60)
       
   358 
       
   359 def authorizer_cb(action, arg1, arg2, dbname, source):
       
   360     if action != sqlite.SQLITE_SELECT:
       
   361         return sqlite.SQLITE_DENY
       
   362     if arg2 == 'c2' or arg1 == 't2':
       
   363         return sqlite.SQLITE_DENY
       
   364     return sqlite.SQLITE_OK
       
   365 
       
   366 class AuthorizerTests(unittest.TestCase):
       
   367     def setUp(self):
       
   368         self.con = sqlite.connect(":memory:")
       
   369         self.con.executescript("""
       
   370             create table t1 (c1, c2);
       
   371             create table t2 (c1, c2);
       
   372             insert into t1 (c1, c2) values (1, 2);
       
   373             insert into t2 (c1, c2) values (4, 5);
       
   374             """)
       
   375 
       
   376         # For our security test:
       
   377         self.con.execute("select c2 from t2")
       
   378 
       
   379         self.con.set_authorizer(authorizer_cb)
       
   380 
       
   381     def tearDown(self):
       
   382         pass
       
   383 
       
   384     def CheckTableAccess(self):
       
   385         try:
       
   386             self.con.execute("select * from t2")
       
   387         except sqlite.DatabaseError, e:
       
   388             if not e.args[0].endswith("prohibited"):
       
   389                 self.fail("wrong exception text: %s" % e.args[0])
       
   390             return
       
   391         self.fail("should have raised an exception due to missing privileges")
       
   392 
       
   393     def CheckColumnAccess(self):
       
   394         try:
       
   395             self.con.execute("select c2 from t1")
       
   396         except sqlite.DatabaseError, e:
       
   397             if not e.args[0].endswith("prohibited"):
       
   398                 self.fail("wrong exception text: %s" % e.args[0])
       
   399             return
       
   400         self.fail("should have raised an exception due to missing privileges")
       
   401 
       
   402 def suite():
       
   403     function_suite = unittest.makeSuite(FunctionTests, "Check")
       
   404     aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
       
   405     authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check")
       
   406     return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite))
       
   407 
       
   408 def test():
       
   409     runner = unittest.TextTestRunner()
       
   410     runner.run(suite())
       
   411 
       
   412 if __name__ == "__main__":
       
   413     test()