|
1 import sys |
|
2 import imp |
|
3 import os |
|
4 import unittest |
|
5 from test import test_support |
|
6 |
|
7 |
|
8 test_src = """\ |
|
9 def get_name(): |
|
10 return __name__ |
|
11 def get_file(): |
|
12 return __file__ |
|
13 """ |
|
14 |
|
15 absimp = "import sub\n" |
|
16 relimp = "from . import sub\n" |
|
17 deeprelimp = "from .... import sub\n" |
|
18 futimp = "from __future__ import absolute_import\n" |
|
19 |
|
20 reload_src = test_src+"""\ |
|
21 reloaded = True |
|
22 """ |
|
23 |
|
24 test_co = compile(test_src, "<???>", "exec") |
|
25 reload_co = compile(reload_src, "<???>", "exec") |
|
26 |
|
27 test2_oldabs_co = compile(absimp + test_src, "<???>", "exec") |
|
28 test2_newabs_co = compile(futimp + absimp + test_src, "<???>", "exec") |
|
29 test2_newrel_co = compile(relimp + test_src, "<???>", "exec") |
|
30 test2_deeprel_co = compile(deeprelimp + test_src, "<???>", "exec") |
|
31 test2_futrel_co = compile(futimp + relimp + test_src, "<???>", "exec") |
|
32 |
|
33 test_path = "!!!_test_!!!" |
|
34 |
|
35 |
|
36 class TestImporter: |
|
37 |
|
38 modules = { |
|
39 "hooktestmodule": (False, test_co), |
|
40 "hooktestpackage": (True, test_co), |
|
41 "hooktestpackage.sub": (True, test_co), |
|
42 "hooktestpackage.sub.subber": (True, test_co), |
|
43 "hooktestpackage.oldabs": (False, test2_oldabs_co), |
|
44 "hooktestpackage.newabs": (False, test2_newabs_co), |
|
45 "hooktestpackage.newrel": (False, test2_newrel_co), |
|
46 "hooktestpackage.sub.subber.subest": (True, test2_deeprel_co), |
|
47 "hooktestpackage.futrel": (False, test2_futrel_co), |
|
48 "sub": (False, test_co), |
|
49 "reloadmodule": (False, test_co), |
|
50 } |
|
51 |
|
52 def __init__(self, path=test_path): |
|
53 if path != test_path: |
|
54 # if out class is on sys.path_hooks, we must raise |
|
55 # ImportError for any path item that we can't handle. |
|
56 raise ImportError |
|
57 self.path = path |
|
58 |
|
59 def _get__path__(self): |
|
60 raise NotImplementedError |
|
61 |
|
62 def find_module(self, fullname, path=None): |
|
63 if fullname in self.modules: |
|
64 return self |
|
65 else: |
|
66 return None |
|
67 |
|
68 def load_module(self, fullname): |
|
69 ispkg, code = self.modules[fullname] |
|
70 mod = sys.modules.setdefault(fullname,imp.new_module(fullname)) |
|
71 mod.__file__ = "<%s>" % self.__class__.__name__ |
|
72 mod.__loader__ = self |
|
73 if ispkg: |
|
74 mod.__path__ = self._get__path__() |
|
75 exec code in mod.__dict__ |
|
76 return mod |
|
77 |
|
78 |
|
79 class MetaImporter(TestImporter): |
|
80 def _get__path__(self): |
|
81 return [] |
|
82 |
|
83 class PathImporter(TestImporter): |
|
84 def _get__path__(self): |
|
85 return [self.path] |
|
86 |
|
87 |
|
88 class ImportBlocker: |
|
89 """Place an ImportBlocker instance on sys.meta_path and you |
|
90 can be sure the modules you specified can't be imported, even |
|
91 if it's a builtin.""" |
|
92 def __init__(self, *namestoblock): |
|
93 self.namestoblock = dict.fromkeys(namestoblock) |
|
94 def find_module(self, fullname, path=None): |
|
95 if fullname in self.namestoblock: |
|
96 return self |
|
97 return None |
|
98 def load_module(self, fullname): |
|
99 raise ImportError, "I dare you" |
|
100 |
|
101 |
|
102 class ImpWrapper: |
|
103 |
|
104 def __init__(self, path=None): |
|
105 if path is not None and not os.path.isdir(path): |
|
106 raise ImportError |
|
107 self.path = path |
|
108 |
|
109 def find_module(self, fullname, path=None): |
|
110 subname = fullname.split(".")[-1] |
|
111 if subname != fullname and self.path is None: |
|
112 return None |
|
113 if self.path is None: |
|
114 path = None |
|
115 else: |
|
116 path = [self.path] |
|
117 try: |
|
118 file, filename, stuff = imp.find_module(subname, path) |
|
119 except ImportError: |
|
120 return None |
|
121 return ImpLoader(file, filename, stuff) |
|
122 |
|
123 |
|
124 class ImpLoader: |
|
125 |
|
126 def __init__(self, file, filename, stuff): |
|
127 self.file = file |
|
128 self.filename = filename |
|
129 self.stuff = stuff |
|
130 |
|
131 def load_module(self, fullname): |
|
132 mod = imp.load_module(fullname, self.file, self.filename, self.stuff) |
|
133 if self.file: |
|
134 self.file.close() |
|
135 mod.__loader__ = self # for introspection |
|
136 return mod |
|
137 |
|
138 |
|
139 class ImportHooksBaseTestCase(unittest.TestCase): |
|
140 |
|
141 def setUp(self): |
|
142 self.path = sys.path[:] |
|
143 self.meta_path = sys.meta_path[:] |
|
144 self.path_hooks = sys.path_hooks[:] |
|
145 sys.path_importer_cache.clear() |
|
146 self.modules_before = sys.modules.copy() |
|
147 |
|
148 def tearDown(self): |
|
149 sys.path[:] = self.path |
|
150 sys.meta_path[:] = self.meta_path |
|
151 sys.path_hooks[:] = self.path_hooks |
|
152 sys.path_importer_cache.clear() |
|
153 sys.modules.clear() |
|
154 sys.modules.update(self.modules_before) |
|
155 |
|
156 |
|
157 class ImportHooksTestCase(ImportHooksBaseTestCase): |
|
158 |
|
159 def doTestImports(self, importer=None): |
|
160 import hooktestmodule |
|
161 import hooktestpackage |
|
162 import hooktestpackage.sub |
|
163 import hooktestpackage.sub.subber |
|
164 self.assertEqual(hooktestmodule.get_name(), |
|
165 "hooktestmodule") |
|
166 self.assertEqual(hooktestpackage.get_name(), |
|
167 "hooktestpackage") |
|
168 self.assertEqual(hooktestpackage.sub.get_name(), |
|
169 "hooktestpackage.sub") |
|
170 self.assertEqual(hooktestpackage.sub.subber.get_name(), |
|
171 "hooktestpackage.sub.subber") |
|
172 if importer: |
|
173 self.assertEqual(hooktestmodule.__loader__, importer) |
|
174 self.assertEqual(hooktestpackage.__loader__, importer) |
|
175 self.assertEqual(hooktestpackage.sub.__loader__, importer) |
|
176 self.assertEqual(hooktestpackage.sub.subber.__loader__, importer) |
|
177 |
|
178 TestImporter.modules['reloadmodule'] = (False, test_co) |
|
179 import reloadmodule |
|
180 self.failIf(hasattr(reloadmodule,'reloaded')) |
|
181 |
|
182 TestImporter.modules['reloadmodule'] = (False, reload_co) |
|
183 reload(reloadmodule) |
|
184 self.failUnless(hasattr(reloadmodule,'reloaded')) |
|
185 |
|
186 import hooktestpackage.oldabs |
|
187 self.assertEqual(hooktestpackage.oldabs.get_name(), |
|
188 "hooktestpackage.oldabs") |
|
189 self.assertEqual(hooktestpackage.oldabs.sub, |
|
190 hooktestpackage.sub) |
|
191 |
|
192 import hooktestpackage.newrel |
|
193 self.assertEqual(hooktestpackage.newrel.get_name(), |
|
194 "hooktestpackage.newrel") |
|
195 self.assertEqual(hooktestpackage.newrel.sub, |
|
196 hooktestpackage.sub) |
|
197 |
|
198 import hooktestpackage.sub.subber.subest as subest |
|
199 self.assertEqual(subest.get_name(), |
|
200 "hooktestpackage.sub.subber.subest") |
|
201 self.assertEqual(subest.sub, |
|
202 hooktestpackage.sub) |
|
203 |
|
204 import hooktestpackage.futrel |
|
205 self.assertEqual(hooktestpackage.futrel.get_name(), |
|
206 "hooktestpackage.futrel") |
|
207 self.assertEqual(hooktestpackage.futrel.sub, |
|
208 hooktestpackage.sub) |
|
209 |
|
210 import sub |
|
211 self.assertEqual(sub.get_name(), "sub") |
|
212 |
|
213 import hooktestpackage.newabs |
|
214 self.assertEqual(hooktestpackage.newabs.get_name(), |
|
215 "hooktestpackage.newabs") |
|
216 self.assertEqual(hooktestpackage.newabs.sub, sub) |
|
217 |
|
218 def testMetaPath(self): |
|
219 i = MetaImporter() |
|
220 sys.meta_path.append(i) |
|
221 self.doTestImports(i) |
|
222 |
|
223 def testPathHook(self): |
|
224 sys.path_hooks.append(PathImporter) |
|
225 sys.path.append(test_path) |
|
226 self.doTestImports() |
|
227 |
|
228 def testBlocker(self): |
|
229 mname = "exceptions" # an arbitrary harmless builtin module |
|
230 if mname in sys.modules: |
|
231 del sys.modules[mname] |
|
232 sys.meta_path.append(ImportBlocker(mname)) |
|
233 try: |
|
234 __import__(mname) |
|
235 except ImportError: |
|
236 pass |
|
237 else: |
|
238 self.fail("'%s' was not supposed to be importable" % mname) |
|
239 |
|
240 def testImpWrapper(self): |
|
241 i = ImpWrapper() |
|
242 sys.meta_path.append(i) |
|
243 sys.path_hooks.append(ImpWrapper) |
|
244 mnames = ("colorsys", "urlparse", "distutils.core", "compiler.misc") |
|
245 for mname in mnames: |
|
246 parent = mname.split(".")[0] |
|
247 for n in sys.modules.keys(): |
|
248 if n.startswith(parent): |
|
249 del sys.modules[n] |
|
250 for mname in mnames: |
|
251 m = __import__(mname, globals(), locals(), ["__dummy__"]) |
|
252 m.__loader__ # to make sure we actually handled the import |
|
253 |
|
254 |
|
255 def test_main(): |
|
256 test_support.run_unittest(ImportHooksTestCase) |
|
257 |
|
258 if __name__ == "__main__": |
|
259 test_main() |