Module bento.graph.transforms
Expand source code
#
# Bentobox
# SDK - Graph
# Transforms
#
import gast
from typing import Callable
from gast import (
AST,
Constant,
Assign,
List,
Tuple,
IfExp,
If,
Store,
Load,
ListComp,
comprehension,
)
from bento.graph.plotter import Plotter
from bento.graph.ast import (
parse_ast,
name_ast,
import_from_ast,
assign_ast,
call_func_ast,
load_ast_module,
wrap_func_ast,
wrap_block_ast,
FuncASTTransform,
)
def transform_build_graph(ast: AST) -> AST:
    """Turns the convert function into the build graph function.

    Requires `analyze_convert_fn()` to analyze the AST first, as the convert
    function is looked up via `ast.convert_fn`.
    Clears the convert function's decorator list and renames the function
    to `build_graph`.

    Args:
        ast: AST with the convert function to transform.
    Returns:
        The given ast (modified in place) with the convert function
        transformed into the build graph function.
    """
    convert_fn = ast.convert_fn
    # strip every decorator from the function definition
    convert_fn.decorator_list = []
    convert_fn.name = "build_graph"
    return ast
def transform_ternary(ast: AST) -> AST:
    """Transforms ternary conditionals to plot Switch Nodes on the computational graph.

    Requires `analyze_convert_fn()` to analyze the AST first, as the plotter's
    name is read from `ast.convert_fn.plotter_name`.
    Rewrites ternary expressions of the form `a if b else c` into a call to the
    Graph Plotter that plots a Switch Node: `g.switch(b, a, c)`.

    Args:
        ast: AST in which ternary expressions should be replaced.
    Returns:
        The AST produced by applying the ternary transform to the given ast.
    """

    def do_transform(node: AST) -> AST:
        if isinstance(node, IfExp):
            # the name of the function to call is taken from the parsed
            # definition of the Plotter's switch function
            switch_fn_def = parse_ast(Plotter.switch).body[0]
            return call_func_ast(
                fn_name=switch_fn_def.name,
                args={
                    "condition": node.test,
                    "true": node.body,
                    "false": node.orelse,
                },
                attr_parent=ast.convert_fn.plotter_name,
            )
        # anything that is not a ternary expression passes through unchanged
        return node

    # apply the ternary transform to the AST
    return FuncASTTransform(do_transform).visit(ast)
def transform_ifelse(ast: AST) -> AST:
"""Transforms if/elif/else statements to plot Switch Nodes on computational graph.
Requires the `analyze_activity()` to analyze the AST first.
Transforms if/elif/else statements to a functions that using the Graph Plotter
trace the operations performed on each conditional branch.
Collects operations for each conditional branch using the Graph plotter
and combines them into a conditional Switch node.
Example:
if a:
x = y
z = 1
elif b:
x = m
z = 2
else:
x = n
z = 3
# should be transformed into:
def __if_block(y, m, n):
x = y
z = 1
return x, z
def __else_block(b, m, n):
def __if_block(m, n):
x = m
z = 2
return x, z
def __else_block(m, n):
x = n
z = 2
return x, z
from copy import deepcopy
__if_condition = deepcopy(b)
__if_outputs = if_block(m=deepcopy(m), n=deepcopy(n))
__else_outputs = else_block(m=deepcopy(m), n=deepcopy(n))
x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)]
return x, z
from copy import deepcopy
__if_condition = deepcopy(a)
__if_outputs = __if_block(b=deepcopy(b), m=deepcopy(m), n=deepcopy(n))
__else_outputs = __else_block(b=deepcopy(n), m=deepcopy(m), n=deepcopy(n))
x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)]
# which will evaluate to:
x = g.switch(a, y, g.switch(b, m, n))
z = g.switch(a, 1, g.switch(b, 2, 3))
Note:
Defining a symbol inside the if statement requires that means that
that symbol has be defined in all conditional branches.
Since all branches of the if statement are traced, symbols used in the
if statement must be `deepcopy()`able. `deepcopy()` is used to ensure
that variables in different branches don't interfere with each other.
Args:
ast: AST to transform if else statements into the plotting of a Switch Node.
Returns:
The given ast with if else statements to be transformed into a the plotting of a Switch Node.
"""
def do_transform(ifelse_ast: AST) -> AST:
# filter out non-ifelse statements
if not isinstance(ifelse_ast, If):
return ifelse_ast
# convert ifelse condition branches into functions with the arguments
# set to the names of the base input symbols and return values set to output symbols.
# base symbols are use to generate the arguments as the full symbol might be qualified
# ie A.x which is not a valid argument name: https://github.com/joeltio/bento-box/issues/37
args = list(ifelse_ast.base_in_syms.keys())
returns = list(ifelse_ast.output_syms.keys())
fn_asts = [
wrap_func_ast(
name=name,
args=args,
block=block,
returns=returns,
# zip() requires the returned outputs to be iterable
return_tuple=True,
)
for name, block in zip(
["__if_block", "__else_block"], [ifelse_ast.body, ifelse_ast.orelse]
)
]
# deepcopy the condition before tracing the if/else block functions to
# prevent side effects tracing from interfering with the condition.
condition_ast = name_ast("__if_condition")
eval_condition_ast = assign_ast(
targets=[condition_ast],
values=[call_func_ast("deepcopy", args=[ifelse_ast.test])],
)
# call if/else block functions to trace results of evaluating each branch
# of the conditional if/else block functions have arguments with the same
# names as symbols we have to pass in.
# deepcopy to prevent input symbols from being passed by reference and
# causing interference between branches https://github.com/joeltio/bento-box/issues/39
import_deepcopy_ast = import_from_ast(module="copy", names=["deepcopy"])
call_args = {
a: call_func_ast(
fn_name="deepcopy",
args=[name_ast(a)],
)
for a in args
}
branch_outputs = [name_ast(n) for n in ["__if_outputs", "__else_outputs"]]
call_fn_asts = [
assign_ast(
targets=[target],
values=[call_func_ast(fn_ast.name, args=call_args)],
)
for target, fn_ast in zip(branch_outputs, fn_asts)
]
# create switch nodes for each output symbol via list comprehension
plot_switch_fn = parse_ast(Plotter.switch).body[0]
# g.switch(test, if_out, else_out)
call_switch_ast = call_func_ast(
fn_name=plot_switch_fn.name,
args={
"condition": condition_ast,
"true": name_ast("if_out"),
"false": name_ast("else_out"),
},
attr_parent=ast.convert_fn.plotter_name,
)
# (symbol, ...) = [g.switch(...) for if_out, else_out in zip(if_outputs, else_outputs)]
switch_asts = assign_ast(
targets=[name_ast(r, ctx=Store()) for r in returns],
values=[
ListComp(
elt=call_switch_ast,
generators=[
comprehension(
target=Tuple(
elts=[
name_ast("if_out"),
name_ast("else_out"),
],
ctx=Load(),
),
iter=call_func_ast(
fn_name="zip",
args=branch_outputs,
),
ifs=[],
is_async=False,
)
],
)
],
force_tuple=True,
)
# wrap transformed code block as single AST node
return wrap_block_ast(
block=fn_asts
+ [import_deepcopy_ast, eval_condition_ast]
+ call_fn_asts
+ [switch_asts],
)
ifelse_transform = FuncASTTransform(do_transform)
# apply ifelse transform to AST
ast = ifelse_transform.visit(ast)
mod = load_ast_module(ast)
return ast
Functions
def transform_build_graph(ast: _ast.AST) ‑> _ast.AST
-
Transforms convert function to build graph function
Requires the
analyze_convert_fn()
to analyze the AST first. Removes the convert function's decorators and changes the name to build_graph
Args
ast
- AST with the convert function to transform
Returns
The given ast with the convert function transformed to build graph function
Expand source code
def transform_build_graph(ast: AST) -> AST: """Transforms convert function to build graph function Requires the `analyze_convert_fn()` to analyze the AST first. Removes the convert function's annotations and changes the name to `build_graph` Args: ast: AST with the convert function to transform Returns: The given ast with the convert function transformed to build graph function """ fn_ast = ast.convert_fn fn_ast.decorator_list, fn_ast.name = [], "build_graph" return ast
def transform_ifelse(ast: _ast.AST) ‑> _ast.AST
-
Transforms if/elif/else statements to plot Switch Nodes on computational graph.
Requires the
analyze_activity()
to analyze the AST first. Transforms if/elif/else statements into functions that use the Graph Plotter to trace the operations performed on each conditional branch. Collects operations for each conditional branch using the Graph Plotter and combines them into a conditional Switch node.
Example:
    if a:
        x = y
        z = 1
    elif b:
        x = m
        z = 2
    else:
        x = n
        z = 3

    # should be transformed into (argument lists are illustrative):
    def __if_block(y, b, m, n):
        x = y
        z = 1
        return x, z

    def __else_block(y, b, m, n):
        def __if_block(m, n):
            x = m
            z = 2
            return x, z

        def __else_block(m, n):
            x = n
            z = 3
            return x, z

        from copy import deepcopy
        __if_condition = deepcopy(b)
        __if_outputs = __if_block(m=deepcopy(m), n=deepcopy(n))
        __else_outputs = __else_block(m=deepcopy(m), n=deepcopy(n))
        x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)]
        return x, z

    from copy import deepcopy
    __if_condition = deepcopy(a)
    __if_outputs = __if_block(y=deepcopy(y), b=deepcopy(b), m=deepcopy(m), n=deepcopy(n))
    __else_outputs = __else_block(y=deepcopy(y), b=deepcopy(b), m=deepcopy(m), n=deepcopy(n))
    x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)]

    # which will evaluate to:
    x = g.switch(a, y, g.switch(b, m, n))
    z = g.switch(a, 1, g.switch(b, 2, 3))
Note
A symbol defined inside the if statement has to be defined in all conditional branches.
Since all branches of the if statement are traced, symbols used in the if statement must be deepcopy()-able.
deepcopy() is used to ensure that variables in different branches don't interfere with each other.
Args
ast
- AST to transform if else statements into the plotting of a Switch Node.
Returns
The given ast with if else statements transformed into the plotting of a Switch Node.
Expand source code
def transform_ifelse(ast: AST) -> AST: """Transforms if/elif/else statements to plot Switch Nodes on computational graph. Requires the `analyze_activity()` to analyze the AST first. Transforms if/elif/else statements to a functions that using the Graph Plotter trace the operations performed on each conditional branch. Collects operations for each conditional branch using the Graph plotter and combines them into a conditional Switch node. Example: if a: x = y z = 1 elif b: x = m z = 2 else: x = n z = 3 # should be transformed into: def __if_block(y, m, n): x = y z = 1 return x, z def __else_block(b, m, n): def __if_block(m, n): x = m z = 2 return x, z def __else_block(m, n): x = n z = 2 return x, z from copy import deepcopy __if_condition = deepcopy(b) __if_outputs = if_block(m=deepcopy(m), n=deepcopy(n)) __else_outputs = else_block(m=deepcopy(m), n=deepcopy(n)) x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)] return x, z from copy import deepcopy __if_condition = deepcopy(a) __if_outputs = __if_block(b=deepcopy(b), m=deepcopy(m), n=deepcopy(n)) __else_outputs = __else_block(b=deepcopy(n), m=deepcopy(m), n=deepcopy(n)) x, z = [g.switch(__if_condition, if_out, else_out) for if_out, else_out in zip(__if_outputs, __else_outputs)] # which will evaluate to: x = g.switch(a, y, g.switch(b, m, n)) z = g.switch(a, 1, g.switch(b, 2, 3)) Note: Defining a symbol inside the if statement requires that means that that symbol has be defined in all conditional branches. Since all branches of the if statement are traced, symbols used in the if statement must be `deepcopy()`able. `deepcopy()` is used to ensure that variables in different branches don't interfere with each other. Args: ast: AST to transform if else statements into the plotting of a Switch Node. Returns: The given ast with if else statements to be transformed into a the plotting of a Switch Node. 
""" def do_transform(ifelse_ast: AST) -> AST: # filter out non-ifelse statements if not isinstance(ifelse_ast, If): return ifelse_ast # convert ifelse condition branches into functions with the arguments # set to the names of the base input symbols and return values set to output symbols. # base symbols are use to generate the arguments as the full symbol might be qualified # ie A.x which is not a valid argument name: https://github.com/joeltio/bento-box/issues/37 args = list(ifelse_ast.base_in_syms.keys()) returns = list(ifelse_ast.output_syms.keys()) fn_asts = [ wrap_func_ast( name=name, args=args, block=block, returns=returns, # zip() requires the returned outputs to be iterable return_tuple=True, ) for name, block in zip( ["__if_block", "__else_block"], [ifelse_ast.body, ifelse_ast.orelse] ) ] # deepcopy the condition before tracing the if/else block functions to # prevent side effects tracing from interfering with the condition. condition_ast = name_ast("__if_condition") eval_condition_ast = assign_ast( targets=[condition_ast], values=[call_func_ast("deepcopy", args=[ifelse_ast.test])], ) # call if/else block functions to trace results of evaluating each branch # of the conditional if/else block functions have arguments with the same # names as symbols we have to pass in. 
# deepcopy to prevent input symbols from being passed by reference and # causing interference between branches https://github.com/joeltio/bento-box/issues/39 import_deepcopy_ast = import_from_ast(module="copy", names=["deepcopy"]) call_args = { a: call_func_ast( fn_name="deepcopy", args=[name_ast(a)], ) for a in args } branch_outputs = [name_ast(n) for n in ["__if_outputs", "__else_outputs"]] call_fn_asts = [ assign_ast( targets=[target], values=[call_func_ast(fn_ast.name, args=call_args)], ) for target, fn_ast in zip(branch_outputs, fn_asts) ] # create switch nodes for each output symbol via list comprehension plot_switch_fn = parse_ast(Plotter.switch).body[0] # g.switch(test, if_out, else_out) call_switch_ast = call_func_ast( fn_name=plot_switch_fn.name, args={ "condition": condition_ast, "true": name_ast("if_out"), "false": name_ast("else_out"), }, attr_parent=ast.convert_fn.plotter_name, ) # (symbol, ...) = [g.switch(...) for if_out, else_out in zip(if_outputs, else_outputs)] switch_asts = assign_ast( targets=[name_ast(r, ctx=Store()) for r in returns], values=[ ListComp( elt=call_switch_ast, generators=[ comprehension( target=Tuple( elts=[ name_ast("if_out"), name_ast("else_out"), ], ctx=Load(), ), iter=call_func_ast( fn_name="zip", args=branch_outputs, ), ifs=[], is_async=False, ) ], ) ], force_tuple=True, ) # wrap transformed code block as single AST node return wrap_block_ast( block=fn_asts + [import_deepcopy_ast, eval_condition_ast] + call_fn_asts + [switch_asts], ) ifelse_transform = FuncASTTransform(do_transform) # apply ifelse transform to AST ast = ifelse_transform.visit(ast) mod = load_ast_module(ast) return ast
def transform_ternary(ast: _ast.AST) ‑> _ast.AST
-
Transforms ternary conditional to plot Switch Nodes on computational graph.
Requires the
analyze_convert_fn()
to analyze the AST first. Transforms ternary statements in the form a if b else c
into a call to the Graph Plotter to plot a Switch Node g.switch(b, a, c)
on the computation graph.
Args
ast
- AST to transform ternary into the plotting of a Switch Node.
Returns
The given ast with ternary statements transformed into the plotting of a Switch Node.
Expand source code
def transform_ternary(ast: AST) -> AST: """Transforms ternary conditional to plot Switch Nodes on computational graph. Requires the `analyze_convert_fn()` to analyze the AST first. Transforms ternary statements in the form `a if b else c` into a call to the Graph Plotter to plot a Switch Node `g.switch(b, a, c)` on the computation graph. Args: ast: AST to transform ternary into the plotting of a Switch Node. Returns: The given ast with ternary statements transformed into the plotting of a Switch Node. """ def do_transform(ternary_ast: AST) -> AST: # filter out non-ternary expressions if not isinstance(ternary_ast, IfExp): return ternary_ast # obtain AST of calling the plotter to plotting a switch node plot_switch_fn = parse_ast(Plotter.switch).body[0] return call_func_ast( fn_name=plot_switch_fn.name, args={ "condition": ternary_ast.test, "true": ternary_ast.body, "false": ternary_ast.orelse, }, attr_parent=ast.convert_fn.plotter_name, ) ternary_transform = FuncASTTransform(do_transform) # apply ternary transform to AST ast = ternary_transform.visit(ast) return ast