|
1 """Fixer for import statements. |
|
2 If spam is being imported from the local directory, this import: |
|
3 from spam import eggs |
|
4 Becomes: |
|
5 from .spam import eggs |
|
6 |
|
7 And this import: |
|
8 import spam |
|
9 Becomes: |
|
10 from . import spam |
|
11 """ |
|
12 |
|
13 # Local imports |
|
14 from .. import fixer_base |
|
15 from os.path import dirname, join, exists, pathsep |
|
16 from ..fixer_util import FromImport |
|
17 |
|
class FixImport(fixer_base.BaseFix):
    """Turn implicit relative imports into explicit relative imports.

    ``from spam import eggs`` becomes ``from .spam import eggs`` and
    ``import spam`` becomes ``from . import spam`` whenever
    probably_a_local_import() judges ``spam`` to live next to the file
    being fixed.
    """

    PATTERN = """
    import_from< type='from' imp=any 'import' ['('] any [')'] >
    |
    import_name< type='import' imp=any >
    """

    def transform(self, node, results):
        """Rewrite *node* as an explicit relative import if it looks local.

        For a ``from`` import the module name is modified in place and None
        is returned; for a plain ``import`` a new ``from . import ...`` node
        carrying the original prefix is returned as the replacement.
        """
        imp = results['imp']
        # unicode(node) includes the whitespace prefix of its first leaf
        # (e.g. u' .' for "from . import x"), so strip it before inspecting.
        imp_name = unicode(imp).strip()

        if imp_name.startswith('.'):
            # Already a new-style import
            return

        if not probably_a_local_import(imp_name, self.filename):
            # I guess this is a global import -- skip it!
            return

        if results['type'].value == 'from':
            # The module part may be a single leaf ('from ham import x') or
            # a dotted_name node ('from ham.eggs import x'); descend to its
            # first leaf and prefix that leaf's value with a dot.
            while not hasattr(imp, 'value'):
                imp = imp.children[0]
            imp.value = "." + imp.value
            node.changed()
        else:
            new = FromImport('.', getattr(imp, 'content', None) or [imp])
            # Keep the original leading whitespace/comments of the statement.
            new.set_prefix(node.get_prefix())
            return new
|
51 |
|
def probably_a_local_import(imp_name, file_path):
    """Guess whether *imp_name* names a module that sits next to *file_path*.

    :param imp_name: the imported name as text from the parse tree; it may
        carry surrounding whitespace and may be dotted (only the first
        component is considered).
    :param file_path: path of the file that contains the import.
    :returns: True if the directory of *file_path* has an ``__init__.py``
        and contains a module (``.py``/``.pyc``/``.so``/``.sl``/``.pyd``)
        or a directory named after the first component of *imp_name*;
        False otherwise.
    """
    from os.path import sep

    # Must be stripped because whitespace from the parse tree (the node
    # prefix) is included when the node is converted to text.
    imp_name = imp_name.strip()
    if not imp_name or imp_name.startswith('.'):
        # Explicit relative names ('.', '.spam') are not sibling modules;
        # without this check '.' would collapse to '' and match the
        # containing directory itself.
        return False
    imp_name = imp_name.split('.', 1)[0]
    base_path = join(dirname(file_path), imp_name)
    # If there is no __init__.py next to the file it's not in a package,
    # so it can't be a relative import.
    if not exists(join(dirname(base_path), '__init__.py')):
        return False
    # Sibling module files, or a sibling package directory ("name" + sep).
    # os.path.sep is the path component separator; os.path.pathsep is the
    # PATH-list separator (':' or ';') and would never match a directory.
    for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd']:
        if exists(base_path + ext):
            return True
    return False