cpython/Parser/asdl.py
Benjamin Peterson 481ae50ccd construct fields in the right order (closes #15517)
Patch from Taihyun Hwang.
2012-07-31 21:41:56 -07:00

420 lines
11 KiB
Python

"""An implementation of the Zephyr Abstract Syntax Definition Language.
See http://asdl.sourceforge.net/ and
http://www.cs.princeton.edu/research/techreps/TR-554-97
Only supports top level module decl, not view. I'm guessing that view
is intended to support the browser and I'm not interested in the
browser.
Changes for Python: Add support for module versions
"""
import os
import sys
import traceback
import spark
def output(string):
    """Write *string* to standard output, followed by a newline."""
    line = string + "\n"
    sys.stdout.write(line)
class Token(object):
    """A lexical token: a ``type`` tag plus the source line it came from.

    spark seems to dispatch in the parser based on a token's ``type``
    attribute, so every token carries one.
    """

    def __init__(self, type, lineno):
        self.type, self.lineno = type, lineno

    def __str__(self):
        # A plain token prints as its type tag.
        return self.type

    def __repr__(self):
        # repr() intentionally yields the same text as str().
        return str(self)
class Id(Token):
    """An identifier token; ``value`` holds the identifier text."""

    def __init__(self, value, lineno):
        # The type tag is fixed to 'Id'; only the text and line vary.
        super(Id, self).__init__('Id', lineno)
        self.value = value

    def __str__(self):
        # Identifiers print as their text rather than as their type tag.
        return self.value
class String(Token):
    """A string-literal token; ``value`` holds the matched text."""

    def __init__(self, value, lineno):
        # The type tag is fixed to 'String'; only the text and line vary.
        super(String, self).__init__('String', lineno)
        self.value = value
class ASDLSyntaxError(Exception):
    """Error raised for ASDL input that cannot be parsed.

    ``lineno`` is the offending line.  Either ``token`` (the token the
    error was detected at) or ``msg`` (a ready-made description) is
    normally supplied; ``msg`` wins when both are given.
    """

    def __init__(self, lineno, token=None, msg=None):
        self.lineno, self.token, self.msg = lineno, token, msg

    def __str__(self):
        if self.msg is not None:
            return "%s, line %d" % (self.msg, self.lineno)
        # No explicit message: fall back to naming the offending token.
        return "Error at '%s', line %d" % (self.token, self.lineno)
class ASDLScanner(spark.GenericScanner, object):
    """Tokenizer for ASDL source text.

    Each ``t_*`` method's docstring is the regular expression for the
    input it handles (spark convention), so those docstrings are part of
    the program and must not be reworded.  The methods are kept in their
    original order.
    """

    def tokenize(self, input):
        # Reset per-run state, let the base class drive the t_* callbacks,
        # then hand back whatever tokens they collected.
        self.rv, self.lineno = [], 1
        super(ASDLScanner, self).tokenize(input)
        return self.rv

    def t_id(self, s):
        r"[\w\.]+"
        # XXX doesn't distinguish upper vs. lower, which is
        # significant for ASDL.
        token = Id(s, self.lineno)
        self.rv.append(token)

    def t_string(self, s):
        r'"[^"]*"'
        token = String(s, self.lineno)
        self.rv.append(token)

    def t_xxx(self, s):  # not sure what this production means
        r"<="
        token = Token(s, self.lineno)
        self.rv.append(token)

    def t_punctuation(self, s):
        r"[\{\}\*\=\|\(\)\,\?\:]"
        # Punctuation tokens use the matched character itself as their type.
        token = Token(s, self.lineno)
        self.rv.append(token)

    def t_comment(self, s):
        r"\-\-[^\n]*"
        # Comments produce no token.

    def t_newline(self, s):
        r"\n"
        # Newlines produce no token; they only advance the line counter.
        self.lineno = self.lineno + 1

    def t_whitespace(self, s):
        r"[ \t]+"
        # Blanks and tabs produce no token.

    def t_default(self, s):
        r" . +"
        raise ValueError("unmatched input: %r" % s)
class ASDLParser(spark.GenericParser, object):
    """Parser turning the scanner's token list into an ASDL tree.

    Each ``p_*`` method's docstring is the grammar rule it reduces (spark
    convention), so those docstrings are part of the program and must not
    be reworded.  Every ``p_*`` method receives the sequence of values
    matched by the right-hand side of its rule and returns the value for
    the left-hand side.
    """

    def __init__(self):
        # "module" is the start symbol of the grammar below.
        super(ASDLParser, self).__init__("module")

    def typestring(self, tok):
        # Grammar terminals are matched against a token's type tag.
        return tok.type

    def error(self, tok):
        raise ASDLSyntaxError(tok.lineno, tok)

    def p_module_0(self, info):
        " module ::= Id Id { } "
        module, name, _0, _1 = info
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        # An empty module has no definitions.  Pass an empty list rather
        # than None: Module iterates over its definitions on construction.
        return Module(name, [])

    def p_module(self, info):
        " module ::= Id Id { definitions } "
        module, name, _0, definitions, _1 = info
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        return Module(name, definitions)

    def p_definition_0(self, definition):
        " definitions ::= definition "
        # A single definition is already a one-element list.
        return definition[0]

    def p_definition_1(self, definitions):
        " definitions ::= definition definitions "
        # Both operands are lists; concatenating keeps source order.
        return definitions[0] + definitions[1]

    def p_definition(self, info):
        " definition ::= Id = type "
        name, _, body = info
        return [Type(name, body)]

    def p_type_0(self, product):
        " type ::= product "
        return product[0]

    def p_type_1(self, sum):
        " type ::= sum "
        return Sum(sum[0])

    def p_type_2(self, info):
        " type ::= sum Id ( fields ) "
        constructors, keyword, _0, attributes, _1 = info
        # The grammar only says "Id"; the keyword itself is checked here.
        if keyword.value != "attributes":
            raise ASDLSyntaxError(keyword.lineno,
                                  msg="expected attributes, found %s" % keyword)
        return Sum(constructors, attributes)

    def p_product(self, info):
        " product ::= ( fields ) "
        _0, fields, _1 = info
        return Product(fields)

    def p_sum_0(self, constructor):
        " sum ::= constructor "
        return [constructor[0]]

    def p_sum_1(self, info):
        " sum ::= constructor | sum "
        constructor, _, rest = info
        return [constructor] + rest

    # NOTE(review): p_sum_2 duplicates p_sum_1 exactly (same rule, same
    # action).  It is kept so the set of registered rules is unchanged;
    # confirm against spark before removing it.
    def p_sum_2(self, info):
        " sum ::= constructor | sum "
        constructor, _, rest = info
        return [constructor] + rest

    def p_constructor_0(self, id):
        " constructor ::= Id "
        return Constructor(id[0])

    def p_constructor_1(self, info):
        " constructor ::= Id ( fields ) "
        name, _0, fields, _1 = info
        return Constructor(name, fields)

    def p_fields_0(self, field):
        " fields ::= field "
        return [field[0]]

    def p_fields_1(self, info):
        " fields ::= fields , field "
        fields, _, field = info
        # Append on the right so fields stay in declaration order.
        return fields + [field]

    def p_field_0(self, type_):
        " field ::= Id "
        return Field(type_[0])

    def p_field_1(self, info):
        " field ::= Id Id "
        field_type, name = info
        return Field(field_type, name)

    def p_field_2(self, info):
        " field ::= Id * Id "
        field_type, _, name = info
        return Field(field_type, name, seq=True)

    def p_field_3(self, info):
        " field ::= Id ? Id "
        field_type, _, name = info
        return Field(field_type, name, opt=True)

    def p_field_4(self, type_):
        " field ::= Id * "
        return Field(type_[0], seq=True)

    def p_field_5(self, type_):
        " field ::= Id ? "
        # Use the matched Id (type_[0]); indexing the builtin ``type``
        # here was a bug.
        return Field(type_[0], opt=True)
# Type names that may be used in fields without being defined in the module.
builtin_types = (
    "identifier",
    "string",
    "bytes",
    "int",
    "object",
)
# below is a collection of classes to capture the AST of an AST :-)
# not sure if any of the methods are useful yet, but I'm adding them
# piecemeal as they seem helpful
class AST(object):
    """Marker base class for every node of the parsed ASDL tree."""
class Module(AST):
    """Root node: a named module and its list of type definitions.

    Attributes:
        name:  the module name as given by the caller.
        dfns:  list of definitions (each exposing ``name.value`` and
               ``value``); empty when the module has none.
        types: dict mapping each defined type name to its value.
    """

    def __init__(self, name, dfns):
        self.name = name
        # A module without definitions may be passed as None; normalise it
        # to an empty list so that iteration here and in callers works.
        self.dfns = [] if dfns is None else dfns
        self.types = {}  # maps type name to value (from dfns)
        for dfn in self.dfns:
            self.types[dfn.name.value] = dfn.value

    def __repr__(self):
        return "Module(%s, %s)" % (self.name, self.dfns)
class Type(AST):
    """One definition: a ``name`` bound to a ``value``."""

    def __init__(self, name, value):
        self.name, self.value = name, value

    def __repr__(self):
        parts = (self.name, self.value)
        return "Type(%s, %s)" % parts
class Constructor(AST):
    """A named constructor together with its (possibly empty) field list."""

    def __init__(self, name, fields=None):
        self.name = name
        # Any falsy ``fields`` (None or empty) becomes a fresh empty list.
        if not fields:
            fields = []
        self.fields = fields

    def __repr__(self):
        parts = (self.name, self.fields)
        return "Constructor(%s, %s)" % parts
class Field(AST):
    """A field: its ``type``, an optional ``name``, and ``seq``/``opt`` flags."""

    def __init__(self, type, name=None, seq=False, opt=False):
        self.type, self.name = type, name
        self.seq, self.opt = seq, opt

    def __repr__(self):
        # At most one flag is shown; seq takes precedence over opt.
        suffix = ""
        if self.seq:
            suffix = ", seq=True"
        elif self.opt:
            suffix = ", opt=True"
        if self.name is not None:
            return "Field(%s, %s%s)" % (self.type, self.name, suffix)
        return "Field(%s%s)" % (self.type, suffix)
class Sum(AST):
    """A sum type: a list of alternatives plus optional shared attributes.

    Attributes:
        types:      the alternatives, as passed in.
        attributes: the attribute list; always a list, empty when none
                    were given.
    """

    def __init__(self, types, attributes=None):
        self.types = types
        self.attributes = attributes or []

    def __repr__(self):
        # ``attributes`` is never None after __init__, so test for
        # emptiness; the old ``is None`` check made the short form
        # unreachable.
        if not self.attributes:
            return "Sum(%s)" % self.types
        return "Sum(%s, %s)" % (self.types, self.attributes)
class Product(AST):
    """A product type: simply a list of fields."""

    def __init__(self, fields):
        self.fields = fields

    def __repr__(self):
        text = "Product(%s)" % self.fields
        return text
class VisitorBase(object):
    """Base class for visitors over the ASDL tree.

    A node of class ``Foo`` is handled by the method ``visitFoo``.  With
    ``skip=True`` nodes lacking such a method are silently ignored;
    otherwise the missing attribute lookup fails.
    """

    def __init__(self, skip=False):
        self.cache = {}  # node class -> bound visit method (or None)
        self.skip = skip

    def visit(self, object, *args):
        handler = self._dispatch(object)
        if handler is not None:
            try:
                handler(object, *args)
            except Exception:
                # Report the failure, then terminate the process at once.
                output("Error visiting" + repr(object))
                output(str(sys.exc_info()[1]))
                traceback.print_exc()
                # XXX hack
                if hasattr(self, 'file'):
                    self.file.flush()
                os._exit(1)

    def _dispatch(self, object):
        assert isinstance(object, AST), repr(object)
        node_class = object.__class__
        handler = self.cache.get(node_class)
        if handler is not None:
            return handler
        method_name = "visit" + node_class.__name__
        if self.skip:
            handler = getattr(self, method_name, None)
        else:
            handler = getattr(self, method_name)
        self.cache[node_class] = handler
        return handler
class Check(VisitorBase):
    """Visitor collecting constructor names and the types used by fields.

    Attributes:
        cons:   constructor name -> name of the type that defines it.
        errors: number of problems reported so far.
        types:  field type name -> list of names it was used in.
    """

    def __init__(self):
        super(Check, self).__init__(skip=True)
        self.cons = {}
        self.errors = 0
        self.types = {}

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        owner = str(type.name)
        self.visit(type.value, owner)

    def visitSum(self, sum, name):
        for alternative in sum.types:
            self.visit(alternative, name)

    def visitConstructor(self, cons, name):
        key = str(cons.name)
        previous = self.cons.get(key)
        if previous is not None:
            # The same constructor name was already recorded.
            output("Redefinition of constructor %s" % key)
            output("Defined in %s and %s" % (previous, name))
            self.errors += 1
        else:
            self.cons[key] = name
        for field in cons.fields:
            self.visit(field, key)

    def visitField(self, field, name):
        # Remember every place in which this field type is used.
        key = str(field.type)
        self.types.setdefault(key, []).append(name)

    def visitProduct(self, prod, name):
        for field in prod.fields:
            self.visit(field, name)
def check(mod):
    """Run the Check visitor over *mod* and report undefined types.

    Returns True when no errors were found, False otherwise.
    """
    checker = Check()
    checker.visit(mod)
    for type_name in checker.types:
        if type_name in mod.types or type_name in builtin_types:
            continue
        # Used by some field but neither defined in the module nor builtin.
        checker.errors += 1
        uses = ", ".join(checker.types[type_name])
        output("Undefined type %s, used in %s" % (type_name, uses))
    return not checker.errors
def parse(file):
    """Parse the ASDL file at path *file*.

    Returns the value produced by the parser on success.  On an
    ASDLSyntaxError the error and the offending source line are printed
    and None is returned.
    """
    scanner = ASDLScanner()
    parser = ASDLParser()
    # Close the file explicitly instead of leaking the handle.
    f = open(file)
    try:
        buf = f.read()
    finally:
        f.close()
    tokens = scanner.tokenize(buf)
    try:
        return parser.parse(tokens)
    except ASDLSyntaxError:
        err = sys.exc_info()[1]
        output(str(err))
        lines = buf.split("\n")
        output(lines[err.lineno - 1])  # lines starts at 0, files at 1
        return None
# Script entry point: parse and check each file named on the command line,
# or every tests/*.asdl file when no arguments are given.
if __name__ == "__main__":
    import glob
    import sys

    if len(sys.argv) > 1:
        files = sys.argv[1:]
    else:
        testdir = "tests"
        files = glob.glob(testdir + "/*.asdl")

    for file in files:
        output(file)
        mod = parse(file)
        if not mod:
            # parse() already reported the error; stop at the first failure.
            break
        # output() takes a single string, so format before calling it.
        output("module %s" % mod.name)
        output("%d definitions" % len(mod.dfns))
        if not check(mod):
            output("Check failed")
        else:
            for dfn in mod.dfns:
                # Type has no ``type`` attribute (only ``name`` and
                # ``value``); print the whole definition instead.
                output(repr(dfn))