python-2.5.2/win32/Lib/test/test_contextlib.py
changeset 0 ae805ac0140d
equal deleted inserted replaced
-1:000000000000 0:ae805ac0140d
       
     1 """Unit tests for contextlib.py, and other context managers."""
       
     2 
       
     3 from __future__ import with_statement
       
     4 
       
     5 import sys
       
     6 import os
       
     7 import decimal
       
     8 import tempfile
       
     9 import unittest
       
    10 import threading
       
    11 from contextlib import *  # Tests __all__
       
    12 from test.test_support import run_suite
       
    13 
       
    14 class ContextManagerTestCase(unittest.TestCase):
       
    15 
       
    16     def test_contextmanager_plain(self):
       
    17         state = []
       
    18         @contextmanager
       
    19         def woohoo():
       
    20             state.append(1)
       
    21             yield 42
       
    22             state.append(999)
       
    23         with woohoo() as x:
       
    24             self.assertEqual(state, [1])
       
    25             self.assertEqual(x, 42)
       
    26             state.append(x)
       
    27         self.assertEqual(state, [1, 42, 999])
       
    28 
       
    29     def test_contextmanager_finally(self):
       
    30         state = []
       
    31         @contextmanager
       
    32         def woohoo():
       
    33             state.append(1)
       
    34             try:
       
    35                 yield 42
       
    36             finally:
       
    37                 state.append(999)
       
    38         try:
       
    39             with woohoo() as x:
       
    40                 self.assertEqual(state, [1])
       
    41                 self.assertEqual(x, 42)
       
    42                 state.append(x)
       
    43                 raise ZeroDivisionError()
       
    44         except ZeroDivisionError:
       
    45             pass
       
    46         else:
       
    47             self.fail("Expected ZeroDivisionError")
       
    48         self.assertEqual(state, [1, 42, 999])
       
    49 
       
    50     def test_contextmanager_no_reraise(self):
       
    51         @contextmanager
       
    52         def whee():
       
    53             yield
       
    54         ctx = whee()
       
    55         ctx.__enter__()
       
    56         # Calling __exit__ should not result in an exception
       
    57         self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
       
    58 
       
    59     def test_contextmanager_trap_yield_after_throw(self):
       
    60         @contextmanager
       
    61         def whoo():
       
    62             try:
       
    63                 yield
       
    64             except:
       
    65                 yield
       
    66         ctx = whoo()
       
    67         ctx.__enter__()
       
    68         self.assertRaises(
       
    69             RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
       
    70         )
       
    71 
       
    72     def test_contextmanager_except(self):
       
    73         state = []
       
    74         @contextmanager
       
    75         def woohoo():
       
    76             state.append(1)
       
    77             try:
       
    78                 yield 42
       
    79             except ZeroDivisionError, e:
       
    80                 state.append(e.args[0])
       
    81                 self.assertEqual(state, [1, 42, 999])
       
    82         with woohoo() as x:
       
    83             self.assertEqual(state, [1])
       
    84             self.assertEqual(x, 42)
       
    85             state.append(x)
       
    86             raise ZeroDivisionError(999)
       
    87         self.assertEqual(state, [1, 42, 999])
       
    88 
       
    89     def test_contextmanager_attribs(self):
       
    90         def attribs(**kw):
       
    91             def decorate(func):
       
    92                 for k,v in kw.items():
       
    93                     setattr(func,k,v)
       
    94                 return func
       
    95             return decorate
       
    96         @contextmanager
       
    97         @attribs(foo='bar')
       
    98         def baz(spam):
       
    99             """Whee!"""
       
   100         self.assertEqual(baz.__name__,'baz')
       
   101         self.assertEqual(baz.foo, 'bar')
       
   102         self.assertEqual(baz.__doc__, "Whee!")
       
   103 
       
   104 class NestedTestCase(unittest.TestCase):
       
   105 
       
   106     # XXX This needs more work
       
   107 
       
   108     def test_nested(self):
       
   109         @contextmanager
       
   110         def a():
       
   111             yield 1
       
   112         @contextmanager
       
   113         def b():
       
   114             yield 2
       
   115         @contextmanager
       
   116         def c():
       
   117             yield 3
       
   118         with nested(a(), b(), c()) as (x, y, z):
       
   119             self.assertEqual(x, 1)
       
   120             self.assertEqual(y, 2)
       
   121             self.assertEqual(z, 3)
       
   122 
       
   123     def test_nested_cleanup(self):
       
   124         state = []
       
   125         @contextmanager
       
   126         def a():
       
   127             state.append(1)
       
   128             try:
       
   129                 yield 2
       
   130             finally:
       
   131                 state.append(3)
       
   132         @contextmanager
       
   133         def b():
       
   134             state.append(4)
       
   135             try:
       
   136                 yield 5
       
   137             finally:
       
   138                 state.append(6)
       
   139         try:
       
   140             with nested(a(), b()) as (x, y):
       
   141                 state.append(x)
       
   142                 state.append(y)
       
   143                 1/0
       
   144         except ZeroDivisionError:
       
   145             self.assertEqual(state, [1, 4, 2, 5, 6, 3])
       
   146         else:
       
   147             self.fail("Didn't raise ZeroDivisionError")
       
   148 
       
   149     def test_nested_right_exception(self):
       
   150         state = []
       
   151         @contextmanager
       
   152         def a():
       
   153             yield 1
       
   154         class b(object):
       
   155             def __enter__(self):
       
   156                 return 2
       
   157             def __exit__(self, *exc_info):
       
   158                 try:
       
   159                     raise Exception()
       
   160                 except:
       
   161                     pass
       
   162         try:
       
   163             with nested(a(), b()) as (x, y):
       
   164                 1/0
       
   165         except ZeroDivisionError:
       
   166             self.assertEqual((x, y), (1, 2))
       
   167         except Exception:
       
   168             self.fail("Reraised wrong exception")
       
   169         else:
       
   170             self.fail("Didn't raise ZeroDivisionError")
       
   171 
       
   172     def test_nested_b_swallows(self):
       
   173         @contextmanager
       
   174         def a():
       
   175             yield
       
   176         @contextmanager
       
   177         def b():
       
   178             try:
       
   179                 yield
       
   180             except:
       
   181                 # Swallow the exception
       
   182                 pass
       
   183         try:
       
   184             with nested(a(), b()):
       
   185                 1/0
       
   186         except ZeroDivisionError:
       
   187             self.fail("Didn't swallow ZeroDivisionError")
       
   188 
       
   189     def test_nested_break(self):
       
   190         @contextmanager
       
   191         def a():
       
   192             yield
       
   193         state = 0
       
   194         while True:
       
   195             state += 1
       
   196             with nested(a(), a()):
       
   197                 break
       
   198             state += 10
       
   199         self.assertEqual(state, 1)
       
   200 
       
   201     def test_nested_continue(self):
       
   202         @contextmanager
       
   203         def a():
       
   204             yield
       
   205         state = 0
       
   206         while state < 3:
       
   207             state += 1
       
   208             with nested(a(), a()):
       
   209                 continue
       
   210             state += 10
       
   211         self.assertEqual(state, 3)
       
   212 
       
   213     def test_nested_return(self):
       
   214         @contextmanager
       
   215         def a():
       
   216             try:
       
   217                 yield
       
   218             except:
       
   219                 pass
       
   220         def foo():
       
   221             with nested(a(), a()):
       
   222                 return 1
       
   223             return 10
       
   224         self.assertEqual(foo(), 1)
       
   225 
       
   226 class ClosingTestCase(unittest.TestCase):
       
   227 
       
   228     # XXX This needs more work
       
   229 
       
   230     def test_closing(self):
       
   231         state = []
       
   232         class C:
       
   233             def close(self):
       
   234                 state.append(1)
       
   235         x = C()
       
   236         self.assertEqual(state, [])
       
   237         with closing(x) as y:
       
   238             self.assertEqual(x, y)
       
   239         self.assertEqual(state, [1])
       
   240 
       
   241     def test_closing_error(self):
       
   242         state = []
       
   243         class C:
       
   244             def close(self):
       
   245                 state.append(1)
       
   246         x = C()
       
   247         self.assertEqual(state, [])
       
   248         try:
       
   249             with closing(x) as y:
       
   250                 self.assertEqual(x, y)
       
   251                 1/0
       
   252         except ZeroDivisionError:
       
   253             self.assertEqual(state, [1])
       
   254         else:
       
   255             self.fail("Didn't raise ZeroDivisionError")
       
   256 
       
   257 class FileContextTestCase(unittest.TestCase):
       
   258 
       
   259     def testWithOpen(self):
       
   260         tfn = tempfile.mktemp()
       
   261         try:
       
   262             f = None
       
   263             with open(tfn, "w") as f:
       
   264                 self.failIf(f.closed)
       
   265                 f.write("Booh\n")
       
   266             self.failUnless(f.closed)
       
   267             f = None
       
   268             try:
       
   269                 with open(tfn, "r") as f:
       
   270                     self.failIf(f.closed)
       
   271                     self.assertEqual(f.read(), "Booh\n")
       
   272                     1/0
       
   273             except ZeroDivisionError:
       
   274                 self.failUnless(f.closed)
       
   275             else:
       
   276                 self.fail("Didn't raise ZeroDivisionError")
       
   277         finally:
       
   278             try:
       
   279                 os.remove(tfn)
       
   280             except os.error:
       
   281                 pass
       
   282 
       
   283 class LockContextTestCase(unittest.TestCase):
       
   284 
       
   285     def boilerPlate(self, lock, locked):
       
   286         self.failIf(locked())
       
   287         with lock:
       
   288             self.failUnless(locked())
       
   289         self.failIf(locked())
       
   290         try:
       
   291             with lock:
       
   292                 self.failUnless(locked())
       
   293                 1/0
       
   294         except ZeroDivisionError:
       
   295             self.failIf(locked())
       
   296         else:
       
   297             self.fail("Didn't raise ZeroDivisionError")
       
   298 
       
   299     def testWithLock(self):
       
   300         lock = threading.Lock()
       
   301         self.boilerPlate(lock, lock.locked)
       
   302 
       
   303     def testWithRLock(self):
       
   304         lock = threading.RLock()
       
   305         self.boilerPlate(lock, lock._is_owned)
       
   306 
       
   307     def testWithCondition(self):
       
   308         lock = threading.Condition()
       
   309         def locked():
       
   310             return lock._is_owned()
       
   311         self.boilerPlate(lock, locked)
       
   312 
       
   313     def testWithSemaphore(self):
       
   314         lock = threading.Semaphore()
       
   315         def locked():
       
   316             if lock.acquire(False):
       
   317                 lock.release()
       
   318                 return False
       
   319             else:
       
   320                 return True
       
   321         self.boilerPlate(lock, locked)
       
   322 
       
   323     def testWithBoundedSemaphore(self):
       
   324         lock = threading.BoundedSemaphore()
       
   325         def locked():
       
   326             if lock.acquire(False):
       
   327                 lock.release()
       
   328                 return False
       
   329             else:
       
   330                 return True
       
   331         self.boilerPlate(lock, locked)
       
   332 
       
   333 # This is needed to make the test actually run under regrtest.py!
       
   334 def test_main():
       
   335     run_suite(
       
   336         unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__])
       
   337     )
       
   338 
       
   339 if __name__ == "__main__":
       
   340     test_main()