gh-123344: Add missing ast optimizations for PEP 696 (#123377)

Co-authored-by: Kirill Podoprigora <kirill.bast9@mail.ru>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
This commit is contained in:
Bogdan Romanyuk 2024-08-28 16:38:56 +03:00 committed by GitHub
parent 9e108b8719
commit be083cee34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 99 additions and 9 deletions

View File

@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
def wrap_expr(self, expr):
return ast.Module(body=[ast.Expr(value=expr)])
def wrap_for(self, for_statement):
return ast.Module(body=[for_statement])
def wrap_statement(self, statement):
return ast.Module(body=[statement])
def assert_ast(self, code, non_optimized_target, optimized_target):
non_optimized_tree = ast.parse(code, optimize=-1)
@ -3090,16 +3090,16 @@ class ASTOptimiziationTests(unittest.TestCase):
f"{ast.dump(optimized_tree)}",
)
def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
return ast.BinOp(left=left, op=self.binop[operand], right=right)
def test_folding_binop(self):
code = "1 %s 1"
operators = self.binop.keys()
def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
return ast.BinOp(left=left, op=self.binop[operand], right=right)
for op in operators:
result_code = code % op
non_optimized_target = self.wrap_expr(create_binop(op))
non_optimized_target = self.wrap_expr(self.create_binop(op))
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
with self.subTest(
@ -3111,7 +3111,7 @@ class ASTOptimiziationTests(unittest.TestCase):
# Multiplication of constant tuples must be folded
code = "(1,) * 3"
non_optimized_target = self.wrap_expr(create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
optimized_target = self.wrap_expr(ast.Constant(eval(code)))
self.assert_ast(code, non_optimized_target, optimized_target)
@ -3222,12 +3222,12 @@ class ASTOptimiziationTests(unittest.TestCase):
]
for left, right, ast_cls, optimized_iter in braces:
non_optimized_target = self.wrap_for(ast.For(
non_optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast_cls(elts=[ast.Constant(1)]),
body=[ast.Pass()]
))
optimized_target = self.wrap_for(ast.For(
optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast.Constant(value=optimized_iter),
body=[ast.Pass()]
@ -3245,6 +3245,92 @@ class ASTOptimiziationTests(unittest.TestCase):
self.assert_ast(code, non_optimized_target, optimized_target)
def test_folding_type_param_in_function_def(self):
code = "def foo[%s = 1 + 1](): pass"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
)
)
non_optimized_target = self.wrap_statement(
ast.FunctionDef(
name='foo',
args=ast.arguments(),
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_type_param_in_class_def(self):
code = "class foo[%s = 1 + 1]: pass"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=ast.Constant(2))]
)
)
non_optimized_target = self.wrap_statement(
ast.ClassDef(
name='foo',
body=[ast.Pass()],
type_params=[type_param(name=name, default_value=unoptimized_binop)]
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
def test_folding_type_param_in_type_alias(self):
code = "type foo[%s = 1 + 1] = 1"
unoptimized_binop = self.create_binop("+")
unoptimized_type_params = [
("T", "T", ast.TypeVar),
("**P", "P", ast.ParamSpec),
("*Ts", "Ts", ast.TypeVarTuple),
]
for type, name, type_param in unoptimized_type_params:
result_code = code % type
optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=ast.Constant(2))],
value=ast.Constant(value=1),
)
)
non_optimized_target = self.wrap_statement(
ast.TypeAlias(
name=ast.Name(id='foo', ctx=ast.Store()),
type_params=[type_param(name=name, default_value=unoptimized_binop)],
value=ast.Constant(value=1),
)
)
self.assert_ast(result_code, non_optimized_target, optimized_target)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1 @@
Add AST optimizations for type parameter defaults.

View File

@ -1087,10 +1087,13 @@ astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
switch (node_->kind) {
case TypeVar_kind:
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
break;
case ParamSpec_kind:
CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
break;
case TypeVarTuple_kind:
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
break;
}
return 1;