Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion jac/jaclang/compiler/passes/main/predynamo_pass.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Pytorch Fix Pass."""

import ast as ast3
from typing import Optional, TypeVar, cast
from copy import deepcopy
from typing import Optional, Sequence, TypeVar, cast

import jaclang.compiler.unitree as uni
from jaclang.compiler.constant import Tokens as Tok
Expand All @@ -14,6 +15,19 @@
class PreDynamoPass(UniPass):
"""Pre-Dynamo pass for PyTorch."""

def before_pass(self) -> None:
    """Initialize per-pass state before the traversal starts.

    Sets up the flag that tracks whether the runtime import must be
    injected and the lookup table of hoistable I/O call names.
    """
    # True once any I/O call has been hoisted; consumed by exit_module
    # to decide whether to inject `import graphmend_runtime as _gm_rt`.
    self.needs_gm_rt = False
    # Call names whose output can safely be buffered and replayed after
    # the compute-heavy region. frozenset: this is a constant lookup
    # table and is never mutated during the pass.
    # TODO(review): make sure these are not user-overloaded names.
    self._HOISTABLE_CALLS = frozenset(
        {
            "print",
            "logging",
            "sys.stdout.write",
            "sys.stderr.write",
            "sys.stdout.flush",
            "sys.stderr.flush",
        }
    )
    return super().before_pass()

def enter_node(self, node: uni.UniNode) -> None:
"""Enter node."""
super().enter_node(node)
Expand Down Expand Up @@ -110,6 +124,169 @@ def check_call(self, node: uni.ExprStmt) -> Optional[tuple]:
return (call.target, name, tensor_expr, kwargs)
return None

def _is_io_call(
    self, node: uni.FuncCall
) -> bool:  # TODO: make sure these are not user overloaded
    """Check if a function call is an I/O operation that should be hoisted.

    Matches plain names (e.g. ``print``) directly, and dotted attribute
    chains (e.g. ``sys.stdout.write``) by rebuilding the dotted path.
    """
    if isinstance(node.target, uni.Name):
        return node.target.value in self._HOISTABLE_CALLS
    elif isinstance(node.target, uni.AtomTrailer):
        # Walk the attribute chain outward-in, collecting each segment.
        parts: list[str] = []
        current = node.target
        while isinstance(current, uni.AtomTrailer):
            if hasattr(current, "right") and isinstance(current.right, uni.Name):
                parts.append(current.right.value)
            # Always advance, even when `right` is not a plain Name
            # (e.g. a call/subscript), to avoid an infinite loop.
            current = current.target
        if isinstance(current, uni.Name):
            parts.append(current.value)

        # Bug fix: the previous code did `any(parts) in self._HOISTABLE_CALLS`,
        # which tests whether the *bool* True/False is in the set — always
        # False, so dotted I/O calls were never detected. Rebuild the dotted
        # name instead (parts were collected rightmost-first).
        dotted = ".".join(reversed(parts))
        return dotted in self._HOISTABLE_CALLS
    return False

def _replace_io_call(self, node: uni.FuncCall) -> uni.FuncCall:
    """Return an I/O function call with a call to the hoisted version.

    Builds ``_gm_io.append(("<name>", (<original args>), {}))`` so the
    original arguments are buffered in the ``_gm_io`` list and replayed
    later by the graphmend runtime.
    """
    # Deep-copy the original arguments so the new tuple shares no nodes
    # with the call being replaced.
    params = deepcopy(node.params)
    # NOTE(review): kid=params — no parenthesis/comma tokens are included;
    # confirm unparse renders this tuple literal correctly.
    tuple_params = uni.TupleVal(values=cast(Sequence[uni.Expr], params), kid=params)
    io_name = node.target
    if isinstance(io_name, uni.Name):
        # The call name as a string literal token, e.g. "print".
        io_str = self.gen_name(node, Tok.STRING, f'"{io_name.value}"')
    else:
        # NOTE(review): dotted targets (e.g. sys.stdout.write) are recorded
        # as "unknown_io" here, and no line number is appended even though
        # the runtime's flush unpacks 4-tuples — confirm replay behavior.
        io_str = self.gen_name(node, Tok.STRING, '"unknown_io"')
    lpr = self.gen_name(node, Tok.LPAREN, "(")
    rpr = self.gen_name(node, Tok.RPAREN, ")")
    # Empty dict literal: placeholder for replay kwargs.
    dict_val = uni.DictVal(kv_pairs=[], kid=[lpr, rpr])
    args = [io_str, tuple_params, dict_val]
    # Build the `_gm_io.append` attribute access target.
    gm_name = self.gen_name(node, Tok.NAME, "_gm_io")
    append_attr = self.gen_name(node, Tok.NAME, "append")
    func_name = uni.AtomTrailer(
        target=gm_name,
        right=append_attr,
        is_attr=True,
        is_null_ok=False,
        kid=[gm_name, append_attr],
    )
    return uni.FuncCall(
        target=func_name,
        params=args,
        genai_call=None,
        kid=[func_name] + args,
    )

def _create_ability(self, node: uni.Ability) -> tuple:
    """Create ability node.

    Clones ``node``'s signature and body into an inner ability named
    ``__gm_core_<name>`` and builds the statements for the new outer body:

        _gm_ret, _gm_io = __gm_core_<name>()
        _gm_rt.graphmend_flush(_gm_io)
        return _gm_ret

    Returns ``(inner_ability, (assign_stmt, flush_stmt, return_stmt))``.
    """
    ability_name = f"__gm_core_{node.name_ref._sym_name}"
    name = self.gen_name(node, Tok.NAME, ability_name)
    name.py_ctx_func = ast3.Load
    kid = [name]
    # The inner ability reuses (deep copies of) the original signature
    # and body so the hoisted core computes exactly what the original did.
    ability = uni.Ability(
        name_ref=name,
        is_async=False,
        is_override=False,
        is_static=False,
        is_abstract=False,
        access=None,
        signature=deepcopy(node.signature),
        body=deepcopy(node.body),
        kid=kid,
    )

    # Call to the inner ability.
    # NOTE(review): params=[] — no arguments are forwarded; abilities that
    # take parameters would lose them here. Confirm against callers.
    call = uni.FuncCall(
        target=name,
        params=[],
        genai_call=None,
        kid=[name],
    )
    # `_gm_ret, _gm_io = __gm_core_<name>()` — the inner ability is expected
    # to return the original result plus the buffered I/O records.
    gm_ret, gm_io = self.gen_name(node, Tok.NAME, "_gm_ret"), self.gen_name(
        node, Tok.NAME, "_gm_io"
    )
    gm_ret.py_ctx_func = ast3.Store
    gm_io.py_ctx_func = ast3.Store
    assign_target = uni.TupleVal(values=[gm_ret, gm_io], kid=[gm_ret, gm_io])
    assign_target.name_spec.py_ctx_func = ast3.Store
    assign_expr = uni.Assignment(
        target=[assign_target],
        value=call,
        type_tag=None,
        kid=[assign_target, call],
    )

    # `_gm_rt.graphmend_flush(_gm_io)` — replay the buffered I/O.
    gm_name = self.gen_name(node, Tok.NAME, "_gm_rt")
    flush_name = self.gen_name(node, Tok.NAME, "graphmend_flush")
    gm_name.py_ctx_func = ast3.Load
    flush_name.py_ctx_func = ast3.Load
    flush_func_name = uni.AtomTrailer(
        target=gm_name,
        right=flush_name,
        is_attr=True,
        is_null_ok=False,
        kid=[gm_name, flush_name],
    )
    # Fresh copies for Load context: the Store-context names above stay
    # in the assignment target.
    gm_io_new = deepcopy(gm_io)
    gm_io_new.py_ctx_func = ast3.Load
    flush_call = uni.FuncCall(
        target=flush_func_name,
        params=[gm_io_new],
        genai_call=None,
        kid=[flush_func_name, gm_io_new],
    )
    flush_expr = uni.ExprStmt(expr=flush_call, in_fstring=False, kid=[flush_call])
    # `return _gm_ret`
    gm_ret_new = deepcopy(gm_ret)
    gm_ret_new.py_ctx_func = ast3.Load
    return_stmt = uni.ReturnStmt(expr=gm_ret_new, kid=[gm_ret_new])

    out_body_parts = (assign_expr, flush_expr, return_stmt)
    return (ability, out_body_parts)

def exit_module(self, node: uni.Module) -> None:
    """Inject `import graphmend_runtime as _gm_rt` when any I/O was hoisted."""
    if not self.needs_gm_rt:
        return
    module_name = self.gen_name(node, Tok.NAME, "graphmend_runtime")
    alias_name = self.gen_name(node, Tok.NAME, "_gm_rt")
    alias_name.py_ctx_func = ast3.Store
    # NOTE(review): the alias is not part of the item's kid list — confirm
    # unparse still emits the "as _gm_rt" clause.
    module_item = uni.ModuleItem(name=module_name, alias=alias_name, kid=[module_name])
    import_stmt = uni.Import(
        from_loc=None,
        items=[module_item],
        is_absorb=False,
        kid=[module_item],
    )
    # Prepend so `_gm_rt` is bound before any hoisted ability runs.
    node.body = [import_stmt, *node.body]
    node.kid = [import_stmt, *node.kid]

def exit_ability(self, node: uni.Ability) -> None:
    """Exit ability.

    If any descendant I/O call flagged this ability as hoistable,
    wrap its original body in an inner `__gm_core_*` ability and
    replace the outer body with call/flush/return statements.
    """
    if getattr(node, "is_hoistable", False):
        self.needs_gm_rt = True
        ability_node, out_body_parts = self._create_ability(node)
        # Resolve the statement list of the (possibly impl-defined) body.
        if isinstance(node.body, list):
            body = node.body
        elif isinstance(node.body, uni.ImplDef) and isinstance(
            node.body.body, list
        ):
            body = node.body.body
        else:
            # Bug fix: `body` was previously left unbound here, raising
            # NameError for any other body shape. Nothing to rewrite.
            body = []
        for stmt in body:
            # NOTE(review): statements are usually ExprStmt wrappers, not
            # bare FuncCall nodes — confirm this branch is ever taken.
            if isinstance(stmt, uni.FuncCall) and self._is_io_call(stmt):
                new_call = self._replace_io_call(stmt)
                self.replace_node(new_call, stmt, "body")
        # Swap the outer body for: inner ability + call/flush/return.
        node.body = [ability_node, *out_body_parts]
        node.kid = [node.kid[0], ability_node, *out_body_parts]

def exit_func_call(self, node: uni.FuncCall) -> None:
    """Exit function call.

    When `node` is a hoistable I/O call, mark its enclosing ability so
    exit_ability wraps it, and splice a `_gm_io.append(...)` call in
    place of the original I/O call.
    """
    if self._is_io_call(node):
        ability_node = node.find_parent_of_type(uni.Ability)
        if ability_node is not None:
            # Children exit before their parents, so this flag is set
            # before exit_ability runs on the enclosing ability.
            ability_node.is_hoistable = True  # type: ignore[attr-defined]

        new_func_call = self._replace_io_call(node)
        # Replace the expression on the wrapping statement, then patch
        # the parent's kid list so unparse sees the new node.
        if isinstance(node.parent, uni.ExprStmt):
            node.parent.expr = new_func_call
            new_func_call.parent = node.parent
        if hasattr(node.parent, "kid") and node in node.parent.kid:
            idx = node.parent.kid.index(node)
            node.parent.kid[idx] = new_func_call

def exit_if_stmt(self, node: uni.IfStmt) -> None:
"""Exit if statement."""
a0 = node.body[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Fixture for the PreDynamoPass I/O tests: the `print` inside the function
# body is the hoistable I/O call the pass is expected to rewrite/buffer.
def print_func(a:int):
    print(a)
    a = a+1
    return a

# Top-level call: executed when the fixture is compiled/imported, so the
# test captures stdout around compilation.
print(print_func(5))
17 changes: 17 additions & 0 deletions jac/jaclang/compiler/passes/main/tests/test_predynamo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,20 @@ def test_predynamo_fix3(self) -> None:
unparsed_code = code_gen.unparse()
self.assertIn("__inv_freq = torch.where(", unparsed_code)
self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)

def test_predynamo_io(self) -> None:
    """Test I/O transformation.

    Compiles the fixture twice — once with the predynamo pass enabled
    and once with it disabled — and asserts the unparsed outputs differ.
    """
    # Capture stdout while compiling: the fixture prints at module level
    # and would otherwise pollute the test output.
    captured_output = io.StringIO()
    sys.stdout = captured_output
    try:
        code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_io.py"))
    finally:
        # Bug fix: stdout was previously not restored if compile raised,
        # silencing every later test's output.
        sys.stdout = sys.__stdout__
    unparsed_code = code_gen.unparse()

    settings.predynamo_pass = False
    try:
        captured_output = io.StringIO()
        sys.stdout = captured_output
        try:
            code_gen = JacProgram().compile(
                self.fixture_abs_path("predynamo_io.py")
            )
        finally:
            sys.stdout = sys.__stdout__
        unparsed_code_original = code_gen.unparse()
    finally:
        # Bug fix: the global setting was previously left disabled when
        # the second compile raised, leaking state into later tests.
        settings.predynamo_pass = True

    self.assertNotEqual(unparsed_code, unparsed_code_original)
96 changes: 96 additions & 0 deletions jac/jaclang/runtimelib/graphmend_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Graphmend runtime I/O utilities."""

from __future__ import annotations

import logging as _lg
import sys as _sys
from contextlib import suppress


def _graphmend_format(x: object) -> str:
"""
Tensor-aware, sync-avoiding formatter.

Avoids .cpu()/.numpy() to prevent device sync in hot paths.
"""
try:
import torch # local import to avoid hard dep if torch is absent

if isinstance(x, torch.Tensor):
try:
shape = tuple(x.shape)
except Exception:
shape = () # Empty tuple instead of string
try:
dtype = str(x.dtype)
except Exception:
dtype = "<unknown>"
try:
device = str(x.device)
except Exception:
device = "<unknown>"
return f"<Tensor shape={shape} dtype={dtype} device={device}>"
except Exception:
pass
try:
return str(x)
except Exception:
return repr(x)


def _graphmend_flush(logs: list[tuple]) -> None:
    """
    Replay buffered I/O records *after* the compute-heavy region.

    Each record: (kind:str, args:tuple, kwargs:dict, lineno:int).
    """
    import builtins as _bi  # defer import to keep runtime light

    for record in logs:
        try:
            kind, raw_args, raw_kwargs, _lineno = record
        except Exception:
            # Malformed record: skip silently, replay must never raise.
            continue

        # Render arguments once via the tensor-aware formatter.
        rendered = tuple(_graphmend_format(item) for item in (raw_args or ()))
        opts = dict(raw_kwargs or {})

        if kind == "print":
            # Drop any custom 'file' target to avoid side-effects; only
            # sep/end/flush are honored, routed to stdout.
            opts.pop("file", None)
            keep = {k: v for k, v in opts.items() if k in ("sep", "end", "flush")}
            _bi.print(*rendered, **keep)

        elif kind == "logging":
            level_name = opts.get("__level__", "info").lower()
            message = " ".join(rendered)
            emit = getattr(_lg, level_name, _lg.info)
            passthrough = {k: v for k, v in opts.items() if k not in ("__level__",)}
            try:
                emit(message, **passthrough)
            except TypeError:
                # Unsupported kwargs for this logging function: retry bare.
                emit(message)

        elif kind == "syswrite":
            text = "".join(rendered)
            stream = opts.get("__stream__", "stdout")
            target = _sys.stdout if stream == "stdout" else _sys.stderr
            try:
                target.write(text)
            except Exception:
                _bi.print(text, end="")

        elif kind == "sysflush":
            stream = opts.get("__stream__", "stdout")
            target = _sys.stdout if stream == "stdout" else _sys.stderr
            with suppress(Exception):
                target.flush()

        else:
            # Unknown kind → best-effort
            _bi.print(*rendered)
Loading