[Python-checkins] r61626 - in sandbox/trunk/2to3/lib2to3: fixes/fix_import.py tests/test_fixers.py

david.wolever python-checkins at python.org
Wed Mar 19 17:19:17 CET 2008


Author: david.wolever
Date: Wed Mar 19 17:19:16 2008
New Revision: 61626

Added:
   sandbox/trunk/2to3/lib2to3/fixes/fix_import.py
Modified:
   sandbox/trunk/2to3/lib2to3/tests/test_fixers.py
Log:
Added fixer for implicit local imports.  See #2414.

Added: sandbox/trunk/2to3/lib2to3/fixes/fix_import.py
==============================================================================
--- (empty file)
+++ sandbox/trunk/2to3/lib2to3/fixes/fix_import.py	Wed Mar 19 17:19:16 2008
@@ -0,0 +1,55 @@
+"""Fixer for import statements.
+If spam is being imported from the local directory, this import:
+    from spam import eggs
+Becomes:
+    from .spam import eggs
+
+And this import:
+    import spam
+Becomes:
+    import .spam
+"""
+
+from os.path import dirname, join, exists, pathsep, sep
+# Local imports
+from . import basefix
+
+class FixImport(basefix.BaseFix):
+
+    PATTERN = """
+    import_from< 'from' imp=any 'import' any >
+    |
+    import_name< 'import' imp=any >
+    """
+
+    def transform(self, node, results):
+        imp = results['imp']
+
+        if unicode(imp).startswith('.'):
+            # Already a new-style import
+            return
+
+        if not probably_a_local_import(unicode(imp), self.filename):
+            # I guess this is a global import -- skip it!
+            return
+
+        # Some imps are top-level (eg: 'import ham')
+        # some are first level (eg: 'import ham.eggs')
+        # some are third level (eg: 'import ham.eggs as spam')
+        # Hence, the loop
+        while not hasattr(imp, 'value'):
+            imp = imp.children[0]
+
+        imp.value = "." + imp.value
+        node.changed()
+        return node
+
+def probably_a_local_import(imp_name, file_path):
+    # Must be stripped because the right space is included by the parser
+    imp_name = imp_name.split('.', 1)[0].strip()
+    base_path = dirname(file_path)
+    base_path = join(base_path, imp_name) 
+    for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']:
+        if exists(base_path + ext):
+            return True
+    return False

Modified: sandbox/trunk/2to3/lib2to3/tests/test_fixers.py
==============================================================================
--- sandbox/trunk/2to3/lib2to3/tests/test_fixers.py	(original)
+++ sandbox/trunk/2to3/lib2to3/tests/test_fixers.py	Wed Mar 19 17:19:16 2008
@@ -10,6 +10,7 @@
 
 # Python imports
 import unittest
+from os.path import dirname, pathsep
 
 # Local imports
 from .. import pygram
@@ -28,6 +29,7 @@
         options = Options(fix=[self.fixer], print_function=False)
         self.refactor = refactor.RefactoringTool(options)
         self.fixer_log = []
+        self.filename = "<string>"
 
         for order in (self.refactor.pre_order, self.refactor.post_order):
             for fixer in order:
@@ -36,7 +38,7 @@
     def _check(self, before, after):
         before = support.reformat(before)
         after = support.reformat(after)
-        tree = self.refactor.refactor_string(before, "<string>")
+        tree = self.refactor.refactor_string(before, self.filename)
         self.failUnlessEqual(after, str(tree))
         return tree
 
@@ -60,7 +62,6 @@
         if not ignore_warnings:
             self.failUnlessEqual(self.fixer_log, [])
 
-
 class Test_ne(FixerTestCase):
     fixer = "ne"
 
@@ -412,7 +413,6 @@
         a = """print(file=sys.stderr)"""
         self.check(b, a)
 
-
 class Test_exec(FixerTestCase):
     fixer = "exec"
 
@@ -464,7 +464,6 @@
         s = """exec(code, ns1, ns2)"""
         self.unchanged(s)
 
-
 class Test_repr(FixerTestCase):
     fixer = "repr"
 
@@ -666,7 +665,6 @@
                 pass"""
         self.unchanged(s)
 
-
 class Test_raise(FixerTestCase):
     fixer = "raise"
 
@@ -789,7 +787,6 @@
                     b = 6"""
         self.check(b, a)
 
-
 class Test_throw(FixerTestCase):
     fixer = "throw"
 
@@ -915,7 +912,6 @@
                     b = 6"""
         self.check(b, a)
 
-
 class Test_long(FixerTestCase):
     fixer = "long"
 
@@ -961,7 +957,6 @@
         a = """x =   int(  x  )"""
         self.check(b, a)
 
-
 class Test_dict(FixerTestCase):
     fixer = "dict"
 
@@ -1171,7 +1166,6 @@
         a = """for i in range(10):\n    j=i"""
         self.check(b, a)
 
-
 class Test_raw_input(FixerTestCase):
     fixer = "raw_input"
 
@@ -1204,7 +1198,6 @@
         a = """x = input(foo(a) + 6)"""
         self.check(b, a)
 
-
 class Test_funcattrs(FixerTestCase):
     fixer = "funcattrs"
 
@@ -1231,7 +1224,6 @@
             s = "f(foo.__%s__.foo)" % attr
             self.unchanged(s)
 
-
 class Test_xreadlines(FixerTestCase):
     fixer = "xreadlines"
 
@@ -1274,7 +1266,6 @@
         s = "foo(xreadlines)"
         self.unchanged(s)
 
-
 class Test_imports(FixerTestCase):
     fixer = "imports"
 
@@ -1352,7 +1343,6 @@
                     """ % (new, member, member, member)
                 self.check(b, a)
 
-
 class Test_input(FixerTestCase):
     fixer = "input"
 
@@ -1400,7 +1390,6 @@
         a = """x = eval(input(foo(5) + 9))"""
         self.check(b, a)
 
-
 class Test_tuple_params(FixerTestCase):
     fixer = "tuple_params"
 
@@ -1620,7 +1609,6 @@
             s = "f(foo.__%s__.foo)" % attr
             self.unchanged(s)
 
-
 class Test_next(FixerTestCase):
     fixer = "next"
 
@@ -2250,7 +2238,6 @@
                 """ % (mod, new, mod, new)
             self.check(b, a)
 
-
 class Test_unicode(FixerTestCase):
     fixer = "unicode"
 
@@ -2904,7 +2891,6 @@
             """
         self.unchanged(s)
 
-
 class Test_basestring(FixerTestCase):
     fixer = "basestring"
 
@@ -2913,7 +2899,6 @@
         a = """isinstance(x, str)"""
         self.check(b, a)
 
-
 class Test_buffer(FixerTestCase):
     fixer = "buffer"
 
@@ -2975,6 +2960,78 @@
         a = """    itertools.filterfalse(a, b)"""
         self.check(b, a)
 
+class Test_import(FixerTestCase):
+    fixer = "import"
+
+    def setUp(self):
+        FixerTestCase.setUp(self)
+        # Need to replace fix_import's isfile and isdir method
+        # so we can check that it's doing the right thing
+        self.files_checked = []
+        self.always_exists = True
+        def fake_exists(name):
+            self.files_checked.append(name)
+            return self.always_exists
+
+        from ..fixes import fix_import
+        fix_import.exists = fake_exists
+
+    def check_both(self, b, a):
+        self.always_exists = True
+        FixerTestCase.check(self, b, a)
+        self.always_exists = False
+        FixerTestCase.unchanged(self, b)
+
+    def test_files_checked(self):
+        def p(path):
+            # Takes a unix path and returns a path with correct separators
+            return pathsep.join(path.split("/"))
+
+        self.always_exists = False
+        expected_extensions = ('.py', pathsep, '.pyc', '.so', '.sl', '.pyd')
+        names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py"))
+
+        for name in names_to_test:
+            self.files_checked = []
+            self.filename = name
+            self.unchanged("import jam")
+
+            if dirname(name): name = dirname(name) + '/jam'
+            else:             name = 'jam'
+            expected_checks = set(name + ext for ext in expected_extensions)
+
+            self.failUnlessEqual(set(self.files_checked), expected_checks)
+
+    def test_from(self):
+        b = "from foo import bar"
+        a = "from .foo import bar"
+        self.check_both(b, a)
+
+    def test_dotted_from(self):
+        b = "from green.eggs import ham"
+        a = "from .green.eggs import ham"
+        self.check_both(b, a)
+
+    def test_from_as(self):
+        b = "from green.eggs import ham as spam"
+        a = "from .green.eggs import ham as spam"
+        self.check_both(b, a)
+
+    def test_import(self):
+        b = "import foo"
+        a = "import .foo"
+        self.check_both(b, a)
+
+    def test_dotted_import(self):
+        b = "import foo.bar"
+        a = "import .foo.bar"
+        self.check_both(b, a)
+
+    def test_dotted_import_as(self):
+        b = "import foo.bar as bang"
+        a = "import .foo.bar as bang"
+        self.check_both(b, a)
+
 
 if __name__ == "__main__":
     import __main__


More information about the Python-checkins mailing list