前回の Python の ast モジュール入門 (抽象構文木を辿る) では、ast モジュールのヘルパー関数を使って抽象構文木を辿ることを紹介しました。
抽象構文木を NodeVisitor で辿る
ast モジュールのヘルパー関数を使うのも1つの方法ですが、ast.NodeVisitor を使うと、もっとお手軽に抽象構文木を辿ることができます。やっていることはヘルパー関数を使うのと同じだというのを NodeVisitor の実装を見た方が分かりやすいのでその実装から紹介します。NodeVisitor は Visitor パターン というデザインパターンの1つです。
class NodeVisitor(object):
def visit(self, node):
"""Visit a node."""
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)
def generic_visit(self, node):
"""Called if no explicit visitor function exists for a node."""
for field, value in iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, AST):
self.visit(item)
elif isinstance(value, AST):
self.visit(value)
visit_NodeClassname が定義されていないときは ast.iter_fields で抽象構文木を辿る generic_visit が実行されます。抽象構文木のノードクラスは ast.AST を基底クラスに取るので、 isinstance(value, AST) でノードインスタンスかどうかを判定して再起的にトラバース (self.visit()) していくようになっています。
実際に使ってみましょう。NodeVisitor を継承したクラスを定義します。
>>> import ast
>>> source = """
... import sys
... def hello(s):
... print('hello {}'.format(s))
... hello('world')
... """
>>> class PrintNodeVisitor(ast.NodeVisitor):
... def visit(self, node):
... print(node)
... return super().visit(node)
...
>>> tree = ast.parse(source)
>>> PrintNodeVisitor().visit(tree)
<_ast.Module object at 0x10bec7b38>
<_ast.Import object at 0x10bec7b70>
<_ast.alias object at 0x10bec7ba8>
<_ast.FunctionDef object at 0x10bec7c18>
<_ast.arguments object at 0x10bec7c50>
<_ast.arg object at 0x10bec7c88>
<_ast.Expr object at 0x10bec7d30>
<_ast.Call object at 0x10bec7d68>
<_ast.Name object at 0x10bec7da0>
<_ast.Load object at 0x10bebe0f0>
<_ast.Call object at 0x10bec7e10>
<_ast.Attribute object at 0x10bec7e48>
<_ast.Str object at 0x10bec7e80>
<_ast.Load object at 0x10bebe0f0>
<_ast.Name object at 0x10bec7eb8>
<_ast.Load object at 0x10bebe0f0>
<_ast.Expr object at 0x10bec7f28>
<_ast.Call object at 0x10bec7f60>
<_ast.Name object at 0x10bec7f98>
<_ast.Load object at 0x10bebe0f0>
<_ast.Str object at 0x10bec7fd0>
簡単に抽象構文木を辿りながらノードを表示することができました。ある特定のノードをフックするには visit_NodeClassname のメソッドを定義します。
>>> class PrintExprNodePisitor(ast.NodeVisitor):
... def visit_Expr(self, node):
... print('Expr is visited')
... return node
...
>>> PrintExprNodePisitor().visit(tree)
Expr is visited
Expr is visited
PrintNodeVisitor での出力と見比べると Expr ノードを2回辿っていることが分かりますね。
抽象構文木を NodeTransformer で変更する
NodeVisitor はトラバースの途中でノードを変更できません。ノードを変更したい用途のときには ast.NodeTransformer を使います。
まずは簡単なソースコードから試してみましょう。
>>> import ast
>>> source = """
... print(s)
... """
>>> s = 'hello world'
>>> code = compile(source, '<string>', 'exec')
>>> exec(code)
hello world
ast.dump を使って、このソースコードがどういった抽象構文木に展開されるかを確認します。
>>> tree = ast.parse(source)
>>> ast.dump(tree)
"Module(body=[Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='s', ctx=Load())], keywords=[], starargs=None, kwargs=None))])"
これを眺めながら何か適当な例を考えてみます。
ここでは例として、出力される文字列を反転してみましょう。いろいろやり方はありそうですが、print 文を別の関数に置き換えるのを試してみます。
>>> class ReversePrintNodeTransformer(ast.NodeTransformer):
... def visit_Name(self, node):
... if node.id == 'print':
... name = ast.Name(id='reverse_print', ctx=ast.Load())
... return ast.copy_location(name, node)
... return node
...
>>> def reverse_print(s):
... print(''.join(reversed(s)))
...
>>> code = compile(ReversePrintNodeTransformer().visit(tree), '<string>', 'exec')
>>> exec(code)
dlrow olleh
>>> s = 'revese print'
>>> exec(code)
tnirp esever
それっぽく動きました。print 文が reverse_print 関数に置き換えられて実行されています。
ast.copy_location を使うと、元のノードから lineno と col_offset をコピーしてくれます。この2つの属性がないと AST オブジェクトを compile できません。
失敗する例を試してみましょう。
>>> from ast import *
>>> expression_without_attr = dump(parse('1 + 1', mode='eval'))
>>> expression_without_attr
'Expression(body=BinOp(left=Num(n=1), op=Add(), right=Num(n=1)))'
>>> code = compile(eval(expression_without_attr), '<string>', 'eval')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: required field "lineno" missing from expr
ここで ast.dump に属性も出力するように include_attributes=True を渡します。
>>> expression_with_attr = dump(parse('1 + 1', mode='eval'), include_attributes=True)
>>> expression_with_attr
'Expression(body=BinOp(left=Num(n=1, lineno=1, col_offset=0), op=Add(), right=Num(n=1, lineno=1, col_offset=4), lineno=1, col_offset=0))'
>>> code = compile(eval(expression_with_attr), '<string>', 'eval')
>>> eval(code)
2
lineno や col_offset を出力することで ast.dump の出力からそのまま AST オブジェクトを生成して (eval して) コンパイルすることもできました。
また別の解決方法として ast.fix_missing_locations を使う方法もあります。先ほどの expression_without_attr を使ってやってみましょう。
>>> code = compile(fix_missing_locations(eval(expression_without_attr)), '<string>', 'eval')
>>> eval(code)
2
今度は compile できましたね。fix_missing_locations のドキュメントによると、
生成されたノードに対しこれらを埋めて回るのはどちらかというと退屈な作業なので、このヘルパーが再帰的に二つの属性がセットされていないものに親ノードと同じ値をセットしていきます。
と、自動的にセットしてくれるそうです。
抽象構文木をいじる (NodeTransformer を使う) とき
実際に抽象構文木をいじって解決したい課題を見つけるのもちょっと難しいのですが、いくつか見つけたものを紹介します。
データとして Python のコードを扱うとき、つまりは扱わないと難しいなにかがあるようなときに思い出すと便利なときもあるかもしれません。