Macro implementation in pure Python

Summary

This post discusses implementing macros in Python through metaprogramming, using the `inspect` and `ast` modules to manipulate function calls at the source level.

Key Concepts:

1. **Source Code Retrieval**: Using the `inspect` module, the post shows how to fetch the source code of a call dynamically, by capturing the caller's stack frame and extracting the exact span of source it is executing.
2. **AST Parsing**: The retrieved source is parsed into an abstract syntax tree (AST) with Python's built-in `ast` module, so it can be inspected and manipulated as a tree.
3. **Macro Functionality**: A decorator turns ordinary functions into macros. When called, they receive their arguments as code (AST nodes or strings) rather than as evaluated runtime values.
4. **Macro Levels**: A `MacroLevel` enum selects whether arguments are passed as AST nodes or as source strings.
5. **Type Compliance**: To keep IDE type checking and syntax highlighting quiet, a `__Placeholder` class is defined whose operators return the placeholder itself, so placeholder expressions can be written and evaluated without doing any real computation.
6. **Practical Example**: The macro system is used to build a shorthand for lambda expressions that uses placeholders (`_` or `_1` to `_9`) for positional arguments.

Technical Details:

- **Inspect Module**: Used to fetch frame and source information.
- **AST Module**: Used to parse and manipulate code at the syntax-tree level.
- **Macros**: Functions become macros through a decorator, enabling code-level manipulation at runtime.
- **Error Handling**: The lambda shorthand rejects mixing `_` with numbered placeholders by raising a `ValueError`.

This approach brings to Python a form of metaprogramming usually associated with languages that have native macro support.

Recently, I got to write some more Python for my algorithms class. Since the material was easy, I started exploring in the REPL and found that Python objects carry quite a lot of metadata. With it, you can retrieve the source code of nearly anything defined in a source file. With that in mind, it's tempting to build a systematic way of metaprogramming.

Retrieving the source code

While there are quite a few ways to retrieve the source code, I chose to use the inspect module directly, which provides access to the stack of frames (function calls) and their corresponding source code.

import inspect
import linecache

# inside the macro: stack[0] is the current frame, stack[1] is its caller
stack = inspect.stack()
frame_info = stack[1]

frame = frame_info.frame
file = inspect.getsourcefile(frame)
# positions (Python 3.11+) spans the expression the caller is executing
position = frame_info.positions

# get source code; column offsets are UTF-8 byte offsets, so slice bytes
lines = map(
    lambda line: linecache.getline(file, line).encode("utf-8"),
    range(position.lineno, position.end_lineno + 1),
)
lines = list(lines)
if len(lines) > 1:
    lines[0] = lines[0][position.col_offset :]
    lines[-1] = lines[-1][: position.end_col_offset]
else:
    # if only one line, slice once since col_offset and end_col_offset
    # are relative to the same line
    lines[0] = lines[0][position.col_offset : position.end_col_offset]
src = b"".join(lines).decode("utf-8")
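The slicing can be checked without a live call stack by running the same logic on an in-memory string and comparing it with the standard library's `ast.get_source_segment`, which extracts a node's source text the same way. One detail worth making explicit: CPython reports `col_offset` and `end_col_offset` as UTF-8 byte offsets, so this sketch slices encoded lines. The helper `slice_segment` and the sample source are made up for illustration.

```python
import ast


def slice_segment(source: str, lineno: int, end_lineno: int,
                  col_offset: int, end_col_offset: int) -> str:
    """Hypothetical helper: cut a span out of `source` like the snippet above.

    Line numbers are 1-based; column offsets are UTF-8 byte offsets.
    """
    lines = [line.encode("utf-8") for line in source.splitlines(keepends=True)]
    lines = lines[lineno - 1 : end_lineno]
    if len(lines) > 1:
        lines[0] = lines[0][col_offset:]
        lines[-1] = lines[-1][:end_col_offset]
    else:
        lines[0] = lines[0][col_offset:end_col_offset]
    return b"".join(lines).decode("utf-8")


# A made-up call spanning two lines, with a non-ASCII character in it
source = 'x = 1\ny = macro(a + b,\n          "é")\n'
call = next(n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.Call))
segment = slice_segment(source, call.lineno, call.end_lineno,
                        call.col_offset, call.end_col_offset)
assert segment == ast.get_source_segment(source, call)
print(segment)
```

Slicing the decoded `str` lines with these offsets instead would be off by one character for every multi-byte character before the cut point (here, the `é` on the last line).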

After retrieving the source code, we can parse it into an AST using Python's built-in ast module.

import ast

tree = ast.parse(src)
# src is a single expression, so the tree is Module -> Expr -> expression
expr = tree.body[0].value
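Parsing with `mode="eval"` is a slightly more direct alternative: the result is an `ast.Expression` whose `body` is the expression itself, and anything that is not a single expression is rejected with a `SyntaxError`. A small sketch with a made-up source string:

```python
import ast

src = "macro(x + 1, y)"

# mode="exec" (the default): Module -> Expr statement -> the expression
via_module = ast.parse(src).body[0].value
# mode="eval": Expression -> the expression
via_eval = ast.parse(src, mode="eval").body

# Both routes reach the same Call node (dump ignores line/column info)
assert isinstance(via_eval, ast.Call)
assert ast.dump(via_module) == ast.dump(via_eval)

# Statements are rejected in eval mode instead of being parsed silently
try:
    ast.parse("y = macro(x)", mode="eval")
except SyntaxError:
    print("not a single expression")
```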

How we process the expression then depends on how the macro is invoked. Here, I only consider two cases: a direct call, macro(...), and a subscript, macro[...].

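if isinstance(expr, ast.Call):
    # macro(a, b, ...) -> list of positional argument nodes
    return expr.args
elif isinstance(expr, ast.Subscript):
    # macro[...] -> a single node (an ast.Tuple for macro[a, b])
    return expr.slice
else:
    raise ValueError("Could not get arguments")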

Once we have the arguments, we can process them as we want.

Putting everything together, we can create a macro decorator that transforms normal functions into macros.

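import ast
import enum
import functools
import inspect
from typing import Callable, TypeVar, cast

R = TypeVar("R")


@enum.unique
class MacroLevel(enum.Enum):
    AST = 1  # pass the arguments as AST nodes
    STR = 2  # pass the arguments as source strings


def macro(*, level: MacroLevel = MacroLevel.AST):
    def decorator(func: Callable[..., R]) -> Callable[..., R]:
        @functools.wraps(func)
        def wrapper(*_, **__):
            # the runtime argument values are ignored; the macro works on the
            # caller's source instead (stack[1] is the caller's frame)
            stack = inspect.stack()
            frame_info = stack[1]
            args_ast = _get_args_ast(frame_info)

            match level, args_ast:
                case MacroLevel.AST, _:
                    args = args_ast
                case MacroLevel.STR, list():
                    # macro(a, b) -> one string per argument
                    args = [ast.unparse(arg) for arg in args_ast]
                case MacroLevel.STR, _:
                    # macro[...] -> a single string
                    args = ast.unparse(args_ast)
            return func(args, frame_info.frame)

        return cast(Callable[..., R], wrapper)

    return decorator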

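Here, _get_args_ast is the function that retrieves the arguments' AST by combining the steps above. The decorated function then receives the processed arguments together with the caller's frame, so it can evaluate code in the caller's scope.

MacroLevel controls how those arguments are passed: as AST nodes (MacroLevel.AST) or as source strings produced by ast.unparse (MacroLevel.STR).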
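With `MacroLevel.STR`, the strings come from `ast.unparse`, so they are regenerated from the tree rather than copied verbatim: spacing is normalized and comments are dropped. A quick sketch of what a macro would receive at each level for a made-up call written as `macro(a+1,  f(x,y))`:

```python
import ast

written = "macro(a+1,  f(x,y))"
args_ast = ast.parse(written, mode="eval").body.args

# What MacroLevel.AST passes: the nodes themselves
print([type(a).__name__ for a in args_ast])  # ['BinOp', 'Call']

# What MacroLevel.STR passes: one normalized string per argument
as_str = [ast.unparse(a) for a in args_ast]
print(as_str)  # ['a + 1', 'f(x, y)']
```

If the exact original spelling mattered, slicing the retrieved source by each argument's positions would preserve it; that is a possible alternative, not what the macro above does.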

Getting the types right

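While the implementation above works, it's annoying when the editor's type checker and syntax highlighting keep complaining about the placeholder names (_, _1, ...) used inside macro calls. To fix this, I created a __Placeholder class that implements Python's operator dunder methods, allowing basically any operation on it. This matters at runtime too: Python evaluates an argument expression such as _ + 1 before the macro is even called, so the placeholders must exist, and every operation simply returns the placeholder itself.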

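class __Placeholder:
    # every operator returns the placeholder itself, so expressions such
    # as `_ + 1` or `2 * _` evaluate without errors
    def __add__(self, _):
        return self

    def __sub__(self, _):
        return self

    def __mul__(self, _):
        return self

    # reflected variants, used when the placeholder is on the right
    __radd__, __rsub__, __rmul__ = __add__, __sub__, __mul__
    # ...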

Example usage of the macro

With the macro in place, it's tempting to build some powerful features on top of it. Inspired by the MacroPy library, I created a similar shorthand for lambda expressions. It is a bit more flexible: besides _, it accepts the numbered placeholders _1 to _9, so an argument can be used more than once or in any order.

import ast
from types import FrameType
from typing import Any


class __f:
    @staticmethod
    @macro(level=MacroLevel.AST)
    def __getitem__(args_src: ast.expr, frame: FrameType):
        # collect the placeholder names used in the expression
        placeholders = set()
        for node in ast.walk(args_src):
            if isinstance(node, ast.Name) and node.id in [
                "_",
                *(f"_{i}" for i in range(1, 10)),
            ]:
                placeholders.add(node.id)

        if "_" in placeholders and placeholders - {"_"}:
            raise ValueError(
                "A quick lambda should use either _ or _1, _2, ..., but not both."
            )

        # the placeholders become the lambda's parameters, in numeric order
        if "_" in placeholders:
            args = ["_"]
        else:
            args = sorted(placeholders, key=lambda x: int(x[1:]))

        args = map(lambda x: ast.arg(arg=x, annotation=None), args)
        ast_func = ast.Lambda(
            args=ast.arguments(
                posonlyargs=[],
                args=list(args),
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=args_src,
        )
        # a lambda compiled by eval() looks up free names in its globals, not
        # in the locals mapping, so merge the caller's locals into a snapshot
        # of its globals to let the body refer to the caller's local variables
        return eval(
            ast.unparse(ast_func),
            {**frame.f_globals, **frame.f_locals},
        )


f = __f()
_: Any = __Placeholder()
_1, _2, _3, _4, _5, _6, _7, _8, _9 = [_] * 9