Modify code with AST

Peter Otten __peter__ at web.de
Wed Jun 12 05:00:44 EDT 2013


Ronny Mandal wrote:

> Hello,
> 
> I am trying to write a script which will parse a code segment (with
> ast.parse()), locate the correct function/method node (by name) in the
> resulting tree and replace this function (node) with another function
> (node), e.g.:
> 
> MyMod1.py:
> 
> class FooBar():
>   def Foo(self): #I want to replace this and only this
>     return 1
> 
>   def Bar(self):
>     return 2
> 
> Here is the parser-class:
> 
> class FindAndTransform(NodeTransformer):
>   """Visit the function and check name"""
>   def visit_FunctionDef(self, node):
>     if node.name == 'Foo': #Only replace if name is "Foo"
>       #Create a new function and assign it to node
>       node = parse('''
> def add(n, m):
>   return n + m
> ''')
>       return node
> 
> When I run the parser on MyMod1.py and generate code (with codegen), the
> output is:
> 
> class FooBar():
>   def add(n, m):
>     return n + m
> 
> i.e. both methods are replaced. It seems like "node" in the parser
> contains all method bodies of class FooBar, not only Foo. When run through
> a debugger, it iterates over both methods. What I really wanted to do was
> to replace only one method (Foo) and leave the other untouched.
> 
> I hope this was understandably conveyed.
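The reported output can be reproduced with the standard library alone. The sketch below mirrors the transformer from the question and inspects the class body directly instead of generating code. SOURCE, ADD_SOURCE and BuggyTransform are illustrative names, not from the original post. It relies on the documented NodeTransformer behaviour that a visitor returning None removes the visited node.

```python
import ast

# Illustrative names; the source text is the one from the question.
SOURCE = """
class FooBar():
  def Foo(self):
    return 1

  def Bar(self):
    return 2
"""

ADD_SOURCE = """
def add(n, m):
  return n + m
"""

class BuggyTransform(ast.NodeTransformer):
    """Same logic as the question: 'return node' sits inside the if block."""
    def visit_FunctionDef(self, node):
        if node.name == 'Foo':
            node = ast.parse(ADD_SOURCE)  # a whole Module, not a FunctionDef
            return node
        # Falling off the end returns None, so Bar is removed from the tree

tree = BuggyTransform().visit(ast.parse(SOURCE))
class_body = tree.body[0].body
print([type(n).__name__ for n in class_body])          # ['Module']
print([type(n).__name__ for n in class_body[0].body])  # ['FunctionDef']
```

The class body ends up holding a single Module that wraps add(), and Bar is gone, which is why the generated code shows only add().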

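I think the main problem is that you have to return the unchanged node. In
your version "return node" sits inside the if block (which might be an
indentation accident), so visit_FunctionDef() returns None for Bar, and a
NodeTransformer removes every node for which the visitor returns None. I also
had to take the add() FunctionDef out of the Module that ast.parse() wraps it
in, hence the .body[0]. Here is a version that compiles and runs the
transformed tree directly (I don't have codegen; is it part of the stdlib?):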

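import ast

class FindAndTransform(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        if node.name == 'Foo':
            # ast.parse() returns a Module; .body[0] is the FunctionDef itself
            node = ast.parse('''
def add(n, m):
  return n + m
''').body[0]
        # Return every node, changed or not; returning None would delete it
        return node

if __name__ == "__main__":
    orig = """
class FooBar():
  def Foo(self): #I want to replace this and only this
    return 1

  def Bar(self):
    return 2
"""

    p = ast.parse(orig)
    q = FindAndTransform().visit(p)
    # Compile the transformed tree directly and run it to define FooBar
    qq = compile(q, "<nofile>", "exec")
    exec(qq)
    # Foo has been replaced by add(); Bar is untouched
    public = {n for n in dir(FooBar) if not n.startswith("_")}
    assert public == {"Bar", "add"}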





More information about the Python-list mailing list