Recently I have been writing more Python for my algorithms class. Since the coursework was easy, I started exploring in the REPL and found that Python objects carry a lot of embedded metadata. As a result, it is possible to retrieve the source code of nearly everything. With that in mind, it's tempting to build a systematic way of doing metaprogramming.
Retrieving the source code
While there are many ways to retrieve the source code, I chose to use the inspect module directly, which provides access to the stack of frames (function calls) and their corresponding source code.
import inspect
import linecache

stack = inspect.stack()
frame_info = stack[1]  # the caller's frame record
frame = frame_info.frame
file = inspect.getsourcefile(frame)
position = frame_info.positions  # available since Python 3.11

# get source code
lines = map(
    lambda line: linecache.getline(file, line),
    range(position.lineno, position.end_lineno + 1),
)
lines = list(lines)
if len(lines) > 1:
    lines[0] = lines[0][position.col_offset :]
    lines[-1] = lines[-1][: position.end_col_offset]
else:
    # if only one line, slice once since col_offset and end_col_offset
    # are relative to the same line
    lines[0] = lines[0][position.col_offset : position.end_col_offset]
src = "".join(lines)
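To see why the single-line case needs its own branch, here is a minimal sketch of just the slicing step, run on hard-coded data instead of a live frame. Span and slice_source are hypothetical names introduced for illustration; Span only mimics the four position fields used above.

```python
from typing import List, NamedTuple


class Span(NamedTuple):
    """Hypothetical stand-in for the position info: 1-based lines, 0-based columns."""
    lineno: int
    end_lineno: int
    col_offset: int
    end_col_offset: int


def slice_source(file_lines: List[str], span: Span) -> str:
    """Return the text covered by `span`; `file_lines` holds every line of the file."""
    lines = file_lines[span.lineno - 1 : span.end_lineno]  # slicing copies the list
    if len(lines) > 1:
        lines[0] = lines[0][span.col_offset :]
        lines[-1] = lines[-1][: span.end_col_offset]
    else:
        # One line only: both offsets refer to the same string, so slice once.
        lines[0] = lines[0][span.col_offset : span.end_col_offset]
    return "".join(lines)


source = [
    "x = 1\n",
    "y = macro(\n",
    "    x + 1,\n",
    ")  # trailing comment\n",
]
multi = slice_source(source, Span(2, 4, 4, 1))
single = slice_source(source, Span(1, 1, 4, 5))
print(repr(multi))   # 'macro(\n    x + 1,\n)'
print(repr(single))  # '1'
```

If the single-line case were sliced in two steps, the first slice would shift the string and the end offset would no longer point at the right character.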
After retrieving the corresponding source code, we can parse it into an AST using Python's built-in ast module.
import ast

tree = ast.parse(src)
expr = tree.body[0].value
How we process the expression from here depends on how the function was called. I only consider two cases: a direct call, macro(...), and a subscript, macro[...].
if isinstance(expr, ast.Call):
    return expr.args
elif isinstance(expr, ast.Subscript):
    return expr.slice
else:
    raise ValueError("Could not get arguments")
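The parsing and dispatch steps can be tried on plain strings without any frame inspection. The following is a minimal sketch, assuming Python 3.9 or later (where a subscript's slice is the expression node itself and ast.unparse is available); get_args_ast is a hypothetical helper that takes source text rather than a frame.

```python
import ast
from typing import List, Union


def get_args_ast(src: str) -> Union[List[ast.expr], ast.expr]:
    """Parse one call-site expression and return the AST of its arguments."""
    expr = ast.parse(src).body[0].value
    if isinstance(expr, ast.Call):
        return expr.args   # list of positional argument nodes
    elif isinstance(expr, ast.Subscript):
        return expr.slice  # a single expression node
    raise ValueError("Could not get arguments")


call_args = get_args_ast("macro(a, b + 1)")
sub_arg = get_args_ast("macro[a * 2]")
print([ast.unparse(a) for a in call_args])  # ['a', 'b + 1']
print(ast.unparse(sub_arg))                 # a * 2
```

Note that the call form yields a list while the subscript form yields a single node, so whatever consumes the result has to handle both shapes.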
Once we have the arguments, we can process them however we want.
Putting everything together, we can create a macro decorator that transforms normal functions into macros.
import ast
import enum
import inspect
from typing import Callable, TypeVar, cast


@enum.unique
class MacroLevel(enum.Enum):
    AST = 1
    STR = 2


def macro(*, level: MacroLevel = MacroLevel.AST):
    R = TypeVar("R")

    def decorator(func: Callable[..., R]) -> Callable[..., R]:
        def wrapper(*_, **__):
            # stack[1] is the frame that invoked the macro
            stack = inspect.stack()
            frame_info = stack[1]
            args_ast = _get_args_ast(frame_info)
            # fmt: off
            args = pipe(
                level,
                match(
                    case(MacroLevel.AST) >> args_ast,
                    case(MacroLevel.STR) >> matchV(args_ast)(
                        case(lambda x: isinstance(x, list)) >> (lambda x: [ast.unparse(arg) for arg in x]),
                        default >> ast.unparse,
                    ),
                )
            )
            # fmt: on
            return func(args, frame_info.frame)

        return cast(Callable[..., R], wrapper)

    return decorator
Here, _get_args_ast is the function that retrieves the arguments' AST. I also implement a MacroLevel enum, which controls how those arguments are passed to the function: as AST nodes or as source strings.
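The level dispatch can also be read as a plain conditional. The sketch below is one interpretation of the pipe/match expression, not the original code: convert_args is a hypothetical name, and it assumes Python 3.9 or later for ast.unparse.

```python
import ast
import enum


@enum.unique
class MacroLevel(enum.Enum):
    AST = 1
    STR = 2


def convert_args(args_ast, level: MacroLevel):
    """Return the arguments as AST nodes (AST level) or as source text (STR level)."""
    if level is MacroLevel.AST:
        return args_ast
    # STR level: a call yields a list of nodes, a subscript yields a single node.
    if isinstance(args_ast, list):
        return [ast.unparse(arg) for arg in args_ast]
    return ast.unparse(args_ast)


call_args = ast.parse("macro(a, b + 1)").body[0].value.args
sub_arg = ast.parse("macro[a * 2]").body[0].value.slice
print(convert_args(call_args, MacroLevel.STR))  # ['a', 'b + 1']
print(convert_args(sub_arg, MacroLevel.STR))    # a * 2
```

With MacroLevel.AST the nodes are handed over untouched, which is what a macro needs if it wants to rewrite the expression rather than just read it.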
Getting the types right
While the implementation above works, it's annoying when the editor's type checker and syntax highlighting keep complaining. To fix this, I created a __Placeholder class, which uses Python's dunder methods to allow basically any operation.
class __Placeholder:
    def __add__(self, _):
        return self

    def __sub__(self, _):
        return self

    def __mul__(self, _):
        return self

    # ...
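Writing every dunder by hand gets repetitive, so one option is to attach them in a loop. This is a minimal sketch of that idea, not the elided part of the class above: Placeholder and _return_self are hypothetical names, and the list of operators is an assumption about which ones are worth covering.

```python
class Placeholder:
    """Hypothetical placeholder: every operator listed below returns the object itself."""


def _return_self(self, *args, **kwargs):
    return self


# Binary operators, plus their reflected forms so that `1 + _` also works.
for _name in ("add", "sub", "mul", "truediv", "floordiv", "mod", "pow",
              "and", "or", "xor", "lshift", "rshift", "matmul"):
    setattr(Placeholder, f"__{_name}__", _return_self)
    setattr(Placeholder, f"__r{_name}__", _return_self)

# Unary operators, ordering comparisons, indexing, and calls.
for _name in ("neg", "pos", "invert", "lt", "le", "gt", "ge", "getitem", "call"):
    setattr(Placeholder, f"__{_name}__", _return_self)

p = Placeholder()
result = (p + 1) * 2 - p  # each step hands back the same object
```

Since the placeholder only has to survive evaluation before the macro replaces the expression, returning the same object everywhere is enough; __eq__ is left alone here so the class stays hashable.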
Example usage of the macro
After completing the macro, it's tempting to implement some powerful features with it. Inspired by the MacroPy library, I created a similar shorthand for lambda expressions (but more powerful, since it lets you specify arbitrary arguments through _1, _2, etc.).
import ast
from types import FrameType
from typing import Any


class __f:
    @staticmethod
    @macro(level=MacroLevel.AST)
    def __getitem__(args_src: ast.expr, frame: FrameType):
        # collect every placeholder name used in the expression
        placeholders = set()
        for node in ast.walk(args_src):
            if isinstance(node, ast.Name) and node.id in [
                "_",
                *(f"_{i}" for i in range(1, 10)),
            ]:
                placeholders.add(node.id)
        if "_" in placeholders and placeholders - {"_"}:
            raise ValueError(
                "A quick lambda should use either _ or _1, _2, ..., but not both."
            )
        if "_" in placeholders:
            args = ["_"]
        else:
            args = sorted(placeholders, key=lambda x: int(x[1:]))
        args = map(lambda x: ast.arg(arg=x, annotation=None), args)
        # build a lambda around the expression and evaluate it in the caller's scope
        ast_func = ast.Lambda(
            args=ast.arguments(
                list(args), [], kwonlyargs=[], kw_defaults=[], defaults=[]
            ),
            body=args_src,
        )
        return eval(
            ast.unparse(ast_func),
            frame.f_globals,
            frame.f_locals,
        )


f = __f()
_: Any = __Placeholder()
_1, _2, _3, _4, _5, _6, _7, _8, _9 = [_] * 9
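The lambda-building step can be tried on its own by feeding it source text directly. This is a minimal sketch, assuming Python 3.9 or later for ast.unparse; quick_lambda and its env parameter are hypothetical and replace the frame lookup with an explicit dictionary of names.

```python
import ast

PLACEHOLDERS = ["_", *(f"_{i}" for i in range(1, 10))]


def quick_lambda(src, env=None):
    """Turn an expression such as '_1 + _2' into a lambda over its placeholders."""
    body = ast.parse(src, mode="eval").body
    found = {
        node.id
        for node in ast.walk(body)
        if isinstance(node, ast.Name) and node.id in PLACEHOLDERS
    }
    if "_" in found and found - {"_"}:
        raise ValueError(
            "A quick lambda should use either _ or _1, _2, ..., but not both."
        )
    names = ["_"] if "_" in found else sorted(found, key=lambda x: int(x[1:]))
    lam = ast.Lambda(
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=name, annotation=None) for name in names],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        ),
        body=body,
    )
    # `env` stands in for the caller's globals and locals.
    return eval(ast.unparse(lam), dict(env or {}))


add_one = quick_lambda("_ + 1")
weighted = quick_lambda("_1 + _2 * scale", {"scale": 10})
print(add_one(3))      # 4
print(weighted(1, 2))  # 21
```

With the macro in place, the same transformation is triggered by writing f[_ + 1] at the call site, and free names such as scale are resolved from the calling frame instead of an explicit dictionary.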