diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala index b65d7427d6..355c8fc475 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Instrumentation.scala @@ -4,12 +4,371 @@ package codegen import utils.* import hkmc2.Message.MessageContext -class Instrumentation(using Raise) extends BlockTransformer(new SymbolSubst()): - def transform(prgm: Program) = Program(prgm.imports, applyBlock(prgm.main)) - - override def applyDefn(d: Defn)(k: Defn => Block): Block = d match - case defn: ClsLikeDefn => - if defn.isym.defn.exists(_.hasStagedModifier.isDefined) && defn.companion.isDefined - then raise(WarningReport(msg"`staged` keyword doesn't do anything currently." -> defn.sym.toLoc :: Nil)) - super.applyDefn(defn)(k) - case b => super.applyDefn(b)(k) +import scala.collection.mutable.HashMap +import scala.util.chaining._ + +import mlscript.utils.*, shorthands.* + +import semantics.* +import semantics.Elaborator.State + +import syntax.{Literal, Tree} + +// it seems some logic should be deferred to BlockTransformer to dedup code +// but it doesn't accept the current context, so applications seem limited + +// it should be possible to cache some common constructions (End, Option) into the context +// this avoids having to rebuild the same shapes everytime they are needed + +// transform Block to Block IR so that it can be instrumented in mlscript +class InstrumentationImpl(using State): + type ArgWrappable = Path | StagedPath | Symbol + type Context = HashMap[Path, StagedPath] + + def asArg(x: ArgWrappable): Arg = + x match + case p: Path => p.asArg + case s: StagedPath => s.code.asArg + case l: Symbol => l.asPath.asArg + + // null and undefined are missing + def toValue(lit: Str | Int | BigDecimal | Bool): Value = + val l = lit match + case i: Int => Tree.IntLit(i) + case b: Bool => Tree.BoolLit(b) + case s: Str => Tree.StrLit(s) + case n: BigDecimal => Tree.DecLit(n) + Value.Lit(l) + + def concat(b1: Block, b2: Block): Block = + b1.mapTail { + case _: End => b2 + case _ => ??? + } + + extension [A, B](ls: Ls[(A => B) => B]) + def collectApply(f: Ls[A] => B): B = + // defer applying k while prepending new elements to the list + ls.foldRight((_: Ls[A] => B)(Nil))((headCont, tailCont) => + k => + headCont: head => + tailCont: tail => + k(head :: tail) + )(f) + + // helpers for constructing Block + + def assign(res: Result, symName: Str = "tmp")(k: Path => Block): Block = + // TODO: skip assignment if res: Path? + val sym = new TempSymbol(N, symName) + Assign(sym, res, k(sym.asPath)) + + def tuple(elems: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + assign(Tuple(false, elems.map(asArg)), symName)(k) + + def ctor(cls: Path, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + assign(Instantiate(false, cls, args.map(asArg)), symName)(k) + + // isMlsFun is probably always true? + def call(fun: Path, args: Ls[ArgWrappable], isMlsFun: Bool = true, symName: Str = "tmp")(k: Path => Block): Block = + assign(Call(fun, args.map(asArg))(isMlsFun, false, false), symName)(k) + + // helpers for instrumenting Block + + def blockMod(name: Str) = summon[State].blockSymbol.asPath.selSN(name) + def optionMod(name: Str) = summon[State].optionSymbol.asPath.selSN(name) + + def blockCtor(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: StagedPath => Block): Block = + ctor(blockMod(name), args, symName)(p => k(StagedPath(p))) + def optionSome(arg: ArgWrappable, symName: Str = "tmp")(k: Path => Block): Block = + ctor(optionMod("Some"), Ls(arg), symName)(k) + def optionNone(symName: Str = "tmp")(k: Path => Block): Block = + assign(optionMod("None"), symName)(k) + + def blockCall(name: Str, args: Ls[ArgWrappable], symName: Str = "tmp")(k: Path => Block): Block = + call(blockMod(name), args, symName = symName)(k) + + case class StagedPath(code: Path) + + // linking functions defined in MLscipt + + def fnPrintCode(p: StagedPath)(k: Block): Block = + // discard result, we only care about side effect + blockCall("printCode", Ls(p))(_ => k) + + def fnConcat(p1: StagedPath, p2: StagedPath, symName: String = "concat")(k: Path => Block): Block = + blockCall("concat", Ls(p1, p2), symName)(k) + + // transformation helpers + + def transformSymbol(sym: Symbol, symName: Str = "sym")(k: StagedPath => Block): Block = + sym match + case clsSym: ClassSymbol => + transformParamsOpt(clsSym.defn.get.paramsOpt): paramsOpt => + blockCtor("ClassSymbol", Ls(toValue(sym.nme), paramsOpt), symName)(k) + case t: TermSymbol if t.defn.exists(_.sym.asCls.isDefined) => + transformSymbol(t.defn.get.sym.asCls.get, symName)(k) + case _ => blockCtor("Symbol", Ls(toValue(sym.nme)), symName)(k) + + def transformOption[A](xOpt: Opt[A], f: A => (Path => Block) => Block)(k: Path => Block): Block = + xOpt match + case S(x) => f(x)(optionSome(_)(k)) + case N => optionNone()(k) + + // instrumentation rules + + def ruleLit(l: Value.Lit, symName: String = "lit")(k: StagedPath => Block): Block = + blockCtor("ValueLit", Ls(l), symName)(k) + + // not in formalization + def ruleVar(r: Value.Ref, symName: String = "var")(k: StagedPath => Block): Block = + val Value.Ref(l, disamb) = r + transformSymbol(disamb.getOrElse(l)): sym => + blockCtor("ValueRef", Ls(sym), symName)(k) + + def ruleTup(t: Tuple, symName: String = "tup")(using Context)(k: StagedPath => Block): Block = + assert(!t.mut, "mutable tuple not supported") + transformArgs(t.elems): xs => + tuple(xs.map(_._1)): codes => + blockCtor("Tuple", Ls(codes), symName)(k) + + def ruleSel(s: Select, symName: String = "sel")(using Context)(k: StagedPath => Block): Block = + val Select(p, i @ Tree.Ident(name)) = s + transformPath(p): x => + blockCtor("Symbol", Ls(toValue(name))): name => + blockCtor("Select", Ls(x, name), symName)(k) + + def ruleDynSel(d: DynSelect, symName: String = "dynsel")(using Context)(k: StagedPath => Block): Block = + transformPath(d.qual): x => + transformPath(d.fld): y => + blockCtor("DynSelect", Ls(x, y, toValue(d.arrayIdx)), symName)(k) + + def ruleApp(c: Call, symName: String = "app")(using Context)(k: StagedPath => Block): Block = + transformPath(c.fun): fun => + transformArgs(c.args): args => + tuple(args.map(_._1)): tup => + blockCtor("Call", Ls(fun, tup), symName)(k) + + def ruleInst(i: Instantiate, symName: String = "inst")(using Context)(k: StagedPath => Block): Block = + val Instantiate(mut, cls, args) = i + assert(!mut, "mutable instantiation not supported") + transformArgs(args): xs => + val sym = cls match + // TODO: if class is staged, we can just use Symbol without storing the arguments + case Value.Ref(l, S(disamb)) => transformSymbol(disamb) + case s: Select if s.symbol.isDefined => transformSymbol(s.symbol.get) + case _ => transformSymbol(TempSymbol(N, "TODO")) + sym: sym => + // reuse instrumentation logic, shape of cls is discarded + // possible to skip this? this uses ruleVar, which is not in formalization + transformPath(cls): cls => + tuple(xs.map(_._1)): codes => + blockCtor("Instantiate", Ls(cls, codes), symName)(k) + + def ruleReturn(r: Return, symName: String = "return")(using Context)(k: (StagedPath, Context) => Block): Block = + transformResult(r.res): x => + blockCtor("Return", Ls(x.code, toValue(false)), symName): cde => + k(cde, summon) + + def ruleMatch(m: Match, symName: String = "match")(using Context)(k: (StagedPath, Context) => Block): Block = + val Match(p, ks, dflt, rest) = m + transformPath(p): x => + ruleBranches(x, p, ks, dflt): (stagedMatch, ctx1) => + transformBlock(rest)(using ctx1): (z, ctx2) => + fnConcat(stagedMatch, z, symName): cde => + k(StagedPath(cde), ctx2) + + def ruleAssign(a: Assign, symName: String = "assign")(using ctx: Context)(k: (StagedPath, Context) => Block): Block = + val Assign(x, r, b) = a + transformResult(r): y => + transformSymbol(x): xSym => + blockCtor("ValueRef", Ls(xSym)): xStaged => + // x should always be defined, either as an argument to the function or in a Scope Block + assert(ctx.get(x.asPath).isDefined) + val x2 = xStaged + (Assign(x, x2.code, _)): + given Context = ctx.clone() += x.asPath -> x2 + transformBlock(b): (z, ctx) => + blockCtor("Assign", Ls(xSym, y, z), symName)(k(_, ctx)) + + def ruleEnd(symName: String = "end")(k: StagedPath => Block): Block = + blockCtor("End", Ls(), symName)(k) + + def ruleBlk(b: Block)(using Context)(k: Path => Block): Block = + transformBlock(b)(k apply _.code) + + def ruleCls(cls: ClsLikeDefn, rest: Block)(using Context)(k: StagedPath => Block): Block = + assert(cls.companion.isEmpty, "nested module not supported") + (Define(cls, _)): + transformBlock(rest): p => + transformParamsOpt(cls.paramsOpt): paramsOpt => + transformSymbol(cls.isym): c => + optionNone(): none => // TODO: handle companion object + blockCtor("ClsLikeDefn", Ls(c, none)): cls => + blockCtor("Define", Ls(cls, p))(k) + + def ruleBranches(x: StagedPath, p: Path, arms: Ls[Case -> Block], dflt: Opt[Block], symName: String = "branches")(using Context)(k: (StagedPath, Context) => Block): Block = + def applyRuleBranch(cse: Case, block: Block)(f: StagedPath => (Context, StagedPath) => Block)(ctx: Context, x: StagedPath): Block = + ruleBranch(x, p, cse, block)(using ctx)((y, ctx, x) => f(y)(ctx, x)) + + val a = arms.map(applyRuleBranch).collectApply + ((f: (Ls[StagedPath], Context) => Block) => a(ys => (ctx, _) => f(ys, ctx))(summon, x)): (arms, ctx) => + tuple(arms): arms => + ruleEnd(): e => + // TODO: use transformOption here + def dfltStaged(k: (Path, Context) => Block) = dflt match + case S(dflt) => ruleWildCard(x, p, dflt): (dflt, ctx) => + optionSome(dflt.code)(k(_, ctx)) + case N => optionNone()(k(_, ctx)) + dfltStaged: (dflt, ctx) => + blockCtor("Match", Ls(x.code, arms, dflt, e), symName)(k(_, ctx)) + + def ruleBranch(x: StagedPath, p: Path, cse: Case, b: Block, symName: String = "branch")(using ctx: Context)(k: (StagedPath, Context, StagedPath) => Block): Block = + transformCase(cse): cse => + transformBlock(b)(using ctx.clone() += p -> x): (y, ctx) => + // TODO: use Arm type instead of Tup + tuple(Ls(cse, y), symName): cde => + val ret = StagedPath(cde) + k(ret, ctx.clone() -= p, x) + + def ruleWildCard(x: StagedPath, p: Path, b: Block)(using ctx: Context)(k: (StagedPath, Context) => Block): Block = + given Context = ctx.clone() += p -> x + transformBlock(b): (y, ctx) => + k(y, ctx.clone() -= p) + + // transformations of Block + + def transformPath(p: Path)(using ctx: Context)(k: StagedPath => Block): Block = + // rulePath + ctx.get(p).map(k).getOrElse: + p match + case r: Value.Ref => ruleVar(r)(k) + case l: Value.Lit => ruleLit(l)(k) + case s: Select => ruleSel(s)(k) + case d: DynSelect => ruleDynSel(d)(k) + case _ => ??? // not supported + + def transformResult(r: Result)(using Context)(k: StagedPath => Block): Block = + r match + case p: Path => transformPath(p)(k) + case t: Tuple => ruleTup(t)(k) + case i: Instantiate => ruleInst(i)(k) + case c: Call => ruleApp(c)(k) + case _ => ??? // not supported + + def transformArg(a: Arg)(using Context)(k: ((StagedPath, Bool)) => Block): Block = + val Arg(spread, value) = a + transformOption(spread, bool => assign(toValue(bool))): spreadStaged => + transformPath(value): value => + blockCtor("Arg", Ls(spreadStaged, value)): cde => + k(cde, spread.isDefined) + + def transformArgs(args: Ls[Arg])(using Context)(k: Ls[(StagedPath, Bool)] => Block): Block = + args.map(transformArg).collectApply(k) + + def transformParamList(ps: ParamList)(k: Path => Block) = + ps.params.map(p => transformSymbol(p.sym)).collectApply(tuple(_)(k)) + + def transformParamsOpt(pOpt: Opt[ParamList])(k: Path => Block) = + transformOption(pOpt, transformParamList)(k) + + def transformCase(cse: Case)(using Context)(k: StagedPath => Block): Block = + cse match + case Case.Lit(lit) => blockCtor("Lit", Ls(Value.Lit(lit)))(k) + case Case.Cls(cls, path) => + transformSymbol(cls): cls => + transformPath(path): path => + blockCtor("Cls", Ls(cls, path))(k) + case Case.Tup(len, inf) => blockCtor("Tup", Ls(len, inf).map(toValue))(k) + case Case.Field(name, safe) => ??? // not supported + + // f.owner returns an InnerSymbol, but we need BlockMemberSymbol of the module to call the function + // so we pass modSym instead + def transformFunDefn(modSym: BlockMemberSymbol, f: FunDefn): (FunDefn, Block) = + val genSym = BlockMemberSymbol(f.sym.nme + "_gen", Nil, true) + val sym = modSym.asPath.selSN(genSym.nme) + // NOTE: this debug printing only works for top-level modules, nested modules don't work + // TODO: remove it. only for test + val debug = + blockCtor("ValueLit", Ls(Value.Lit(Tree.UnitLit(false)))): undef => + // TODO: put correct parameters instead of End + // TODO: handle curried arguments + val argsList = f.params.map(ps => List.fill(ps.params.length)(undef)) + def makeCalls(k: Path => Block) = + argsList.foldRight(k)((args, cont) => res => call(res, args)(cont))(sym) + makeCalls: ret => + val p = StagedPath(ret) + fnPrintCode(p)(End()) + + val dSym = TermSymbol(f.dSym.k, f.dSym.owner, Tree.Ident(f.sym.nme + "_gen")) + val args = f.params.flatMap(_.params).map(_.sym) + val newBody = + ruleEnd(): end => + given Context = HashMap(args.map(s => Value.Ref(s, N) -> StagedPath(Value.Ref(s, N)))*) + transformBlock(f.body)(p => Return(p.code, false)) + val newFun = f.copy(sym = genSym, dSym = dSym, body = newBody)(false) + (newFun, debug) + + def transformDefine(d: Define)(using Context)(k: (StagedPath, Context) => Block): Block = + d.defn match + case c: ClsLikeDefn => + ruleCls(c, d.rest): p => + ruleEnd(): b => + fnPrintCode(p)(k(b, summon)) + case _: FunDefn | _: ValDefn => ??? + + // TODO + // discards result of sub + def transformBegin(b: Begin)(using Context)(k: (StagedPath, Context) => Block): Block = + transformBlock(b.sub): (sub, ctx) => + transformBlock(b.rest)(using ctx): (rest, ctx) => + fnConcat(sub, rest): block => + k(StagedPath(block), ctx) + + def transformScoped(s: Scoped)(using ctx: Context)(k: (StagedPath, Context) => Block): Block = + val Scoped(syms, body) = s + blockCtor("ValueLit", Ls(Value.Lit(Tree.UnitLit(false)))): undef => + val newCtx = ctx.clone() ++ syms.map(_.asPath -> undef) + transformBlock(body)(using newCtx): (p, ctx) => + k(p, ctx) + + // ruleBlk? + def transformBlock(b: Block)(using Context)(k: StagedPath => Block): Block = + transformBlock(b)((p, _) => k(p)) + + def transformBlock(b: Block)(using Context)(k: (StagedPath, Context) => Block): Block = + b match + case r: Return => ruleReturn(r)(k) + case a: Assign => ruleAssign(a)(k) + case d: Define => transformDefine(d)(k) + case End(_) => ruleEnd()(k(_, summon)) + case m: Match => ruleMatch(m)(k) + // temporary measure to accept returning an array + // use BlockTransformer here? + case b: Begin => transformBegin(b)(k) + // case Begin(b1, b2) => transformBlock(concat(b1, b2))(k) + case s: Scoped => transformScoped(s)(k) + case _ => ??? // not supported + +// TODO: rename as InstrumentationTransformer? +class Instrumentation(using State) extends BlockTransformer(new SymbolSubst()): + val impl = new InstrumentationImpl + + override def applyBlock(b: Block): Block = super.applyBlock(b) match + case d @ Define(defn, rest) => + defn match + // find modules with staged annotation + case c: ClsLikeDefn if c.companion.exists(_.isym.defn.exists(_.hasStagedModifier.isDefined)) => + val sym = c.sym.subst + val companion = c.companion.get + val (stagedMethods, debugPrintCode) = companion.methods + .map(impl.transformFunDefn(sym, _)) + .unzip + val newCtor = impl.transformBlock(companion.ctor)(using new HashMap())(_ => End()) + val newCompanion = companion.copy(methods = companion.methods ++ stagedMethods, ctor = newCtor) + val newModule = c.copy(sym = sym, companion = S(newCompanion)) + // debug is printed after definition + val debugBlock = debugPrintCode.foldRight(rest)(impl.concat) + Define(newModule, debugBlock) + case _ => d + case b => b diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 8d049148f2..50a9e50b9b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -20,7 +20,7 @@ import semantics.Term.{Throw => _, *} import semantics.Elaborator.{State, Ctx, ctx} import syntax.{Literal, Tree} -import hkmc2.syntax.Fun +import hkmc2.syntax.{Fun, Keyword} abstract class TailOp extends (Result => Block) @@ -264,14 +264,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx): mod.classCompanion match case S(comp) => comp.defn.getOrElse(wat("Module companion without definition", mod.companion)) case N => - ClassDef.Plain(mod.owner, syntax.Cls, new ClassSymbol(Tree.DummyTypeDef(syntax.Cls), mod.sym.id), + val clsSymb = new ClassSymbol(Tree.DummyTypeDef(syntax.Cls), mod.sym.id) + val stagedAnnots = mod.annotations.collect { + case Annot.Modifier(Keyword.`staged`) => Annot.Modifier(Keyword.`staged`) + } + val newDefn = ClassDef.Plain(mod.owner, syntax.Cls, clsSymb, mod.bsym, Nil, N, ObjBody(Blk(Nil, UnitVal())), S(mod.sym), - Nil, + stagedAnnots ) + clsSymb.defn = S(newDefn) + newDefn case _ => _defn reportAnnotations(defn, defn.extraAnnotations) val bufferableAnnots = defn.annotations.flatMap: @@ -1059,7 +1065,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val merged = MergeMatchArmTransformer.applyBlock(bufferable) val staged = - if config.stageCode then Instrumentation(using summon).applyBlock(merged) + if config.stageCode then Instrumentation().applyBlock(merged) else merged val res = diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index 00dc4e6160..53f695a9e2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -101,8 +101,9 @@ object Printer: case Select(qual, name) => val docQual = mkDocument(qual) doc"${docQual}.${name.name}" + case DynSelect(qual, fld, ai) => + doc"${mkDocument(qual)}.(${mkDocument(fld)})" case x: Value => mkDocument(x) - case _ => TODO(path) def mkDocument(result: Result)(using Raise, Scope): Document = result match case Call(fun, args) => doc"${mkDocument(fun)}(${args.map(mkDocument).mkDocument(", ")})" diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 6cd7793ae2..aecf080a45 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -252,7 +252,8 @@ object Elaborator: val prettyPrintSymbol = TempSymbol(N, "prettyPrint") val termSymbol = TempSymbol(N, "Term") val blockSymbol = TempSymbol(N, "Block") - val shapeSymbol = TempSymbol(N, "Shape") + val shapeSetSymbol = TempSymbol(N, "shapeSet") + val optionSymbol = TempSymbol(N, "option") val wasmSymbol = TempSymbol(N, "wasm") val effectSigSymbol = ClassSymbol(DummyTypeDef(syntax.Cls), Ident("EffectSig")) val nonLocalRetHandlerTrm = diff --git a/hkmc2/shared/src/test/mlscript-compile/Block.mls b/hkmc2/shared/src/test/mlscript-compile/Block.mls index e69de29bb2..d8ec8b1b3f 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Block.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Block.mls @@ -0,0 +1,176 @@ +import "./Predef.mls" +import "./Option.mls" +import "./StrOps.mls" + +open Predef +open StrOps +open Option + +module Block with... + +type Opt[A] = Option[A] + +// dependancies referenced in Block classes, referencing implementation in Term.mls + +type Literal = null | undefined | Str | Int | Num | Bool + +type ParamList = Array[Symbol] + +class Symbol(val name: Str) +type Ident = Symbol +// this is so that we're able to retrieve information about the class from the symbol +class ClassSymbol(val name: Str, val paramsOpt: Opt[ParamList]) extends Symbol(name) + +class Arg(val spread: Opt[Bool], val value: Path) + +class Case with + constructor + Lit(val lit: Literal) + Cls(val cls: Symbol, val path: Path) + Tup(val len: Int, val inf: Bool) + +class Result with + constructor + Call(val _fun: Path, val args: Array[Arg]) + Instantiate(val cls: Path, val args: Array[Arg]) // assume immutable + Tuple(val elems: Array[Arg]) // assume immutable + +class Path extends Result with + constructor + Select(val qual: Path, val name: Ident) + DynSelect(val qual: Path, val fld: Path, val arrayIdx: Bool) // is arrayIdx used? + ValueRef(val l: Symbol) + ValueLit(val lit: Literal) + +class Defn with + constructor + ValDefn(val sym: Symbol, val rhs: Path) + ClsLikeDefn(val sym: ClassSymbol, val companion: Opt[ClsLikeBody]) // companion unused + FunDefn(val sym: Symbol, val params: Array[ParamList], val body: Block, val stage: Bool) + +class ClsLikeBody(val isym: Symbol, val methods: Array[FunDefn], val publicFields: Array[[Symbol, Symbol]]) // unused + +class Block with + constructor + Match(val scrut: Path, val arms: Array[[Case, Block]], val dflt: Opt[Block], val rest: Block) + Return(val res: Result, val implct: Bool) + Assign(val lhs: Symbol, val rhs: Result, val rest: Block) + Define(val defn: Defn, val rest: Block) + End() + +fun concat(b1: Block, b2: Block) = if b1 is + Match(scrut, arms, dflt, rest) then Match(scrut, arms, dflt, concat(rest, b2)) + Return(res, implct) then b2 // discard return? + Assign(lhs, rhs, rest) then Assign(lhs, rhs, concat(rest, b2)) + Define(defn, rest) then Define(defn, concat(rest, b2)) + End() then b2 + +fun showBool(b: Bool) = if b then "true" else "false" + +fun showLiteral(l: Literal) = + if l is + undefined then "undefined" + null then "null" + String then "\"" + l.toString() + "\"" + else l.toString() + +fun showSymbol(s: Symbol) = + // console.log("printing " + s) + if s is + ClassSymbol(name, args) then + "ClassSymbol(" + "\"" + name + "\"" + + if args + is Some(args) then ":[" + args.map(showSymbol).join(", ") + "]" + is None then "" + + ")" + _ then "Symbol(" + "\"" + s.name + "\"" + ")" + +fun showIdent(i: Ident) = showSymbol(i) + +fun showPath(p: Path): Str = + if p is + Select(qual, name) then + "Select(" + showPath(qual) + ", " + showIdent(name) + ")" + DynSelect(qual, fld, arrayIdx) then + "DynSelect(" + showPath(qual) + ", " + showPath(fld) + ", " + showBool(arrayIdx) + ")" + ValueRef(l) then + "Ref(" + showSymbol(l) + ")" + ValueLit(lit) then + "Lit(" + showLiteral(lit) + ")" + +fun showArg(arg: Arg) = + if arg.spread is + Some(true) then "..." + Some(false) then ".." + else "" + + showPath(arg.value) + +fun showArgs(args: Array[Arg]) = + "[" + args.map(showArg).join(", ") + "]" + +// Case (match arm patterns) +fun showCase(c: Case): Str = + if c is + Lit(lit) then "Lit(" + showLiteral(lit) + ")" + Cls(cls, path) then "Cls(" + showSymbol(cls) + ", " + showPath(path) + ")" + Tup(len, inf) then "Tup(" + len + ", " + inf + ")" + +fun showResult(r: Result): Str = + if r is + Path then showPath(r) + Call(f, args) then "Call(" + showPath(f) + ", " + showArgs(args) + ")" + Instantiate(cls, args) then "Instantiate(" + showPath(cls) + ", " + showArgs(args) + ")" + Tuple(elems) then "Tuple(" + showArgs(elems) + ")" + +fun showParamList(ps: ParamList) = + "[" + ps.map(s => showSymbol(s)).join(", ") + "]" + +fun showDefn(d: Defn): Str = + if d is + ValDefn(sym, rhs) then + "ValDefn(" + showSymbol(sym) + ", " + showPath(rhs) + ")" + FunDefn(sym, params, body, stage) then + "FunDefn(" + showSymbol(sym) + ", " + + "(" + params.map(showParamList) + "), " + + showBlock(body) + ", " + + stage + ")" + ClsLikeDefn(sym, companion) then + // TODO: print rest of the arguments + "ClsLikeDefn(" + showSymbol(sym) + ", " + "TODO" + ")" + +fun showOptBlock(ob: Opt[Block]) = + if ob is Some(b) then showBlock(b) else "None" + +fun showArm(pair: Case -> Block) = + if pair is [cse, body] then showCase(cse) + " -> " + showBlock(body) else "" + +fun showBlock(b: Block): Str = + if b is + Match(scrut, arms, dflt, rest) then + "Match(" + + showPath(scrut) + ", " + + "[" + arms.map(showArm).join(", ") + "], " + + showOptBlock(dflt) + ", " + + showBlock(rest) + ")" + Return(res, implct) then + "Return(" + showResult(res) + ", " + showBool(implct) + ")" + Assign(lhs, rhs, rest) then + "Assign(" + showSymbol(lhs) + ", " + showResult(rhs) + ", " + showBlock(rest) + ")" + Define(defn, rest) then + "Define(" + showDefn(defn) + ", " + showBlock(rest) + ")" + End() then "End" + +fun show(x) = + if x is + Symbol then showSymbol(x) + Path then showPath(x) + Result then showResult(x) + Case then showCase(x) + Defn then showDefn(x) + Block then showBlock(x) + else + "" + +fun printCode(x) = print(show(x)) + +fun compile(p: Block) = ??? \ No newline at end of file diff --git a/hkmc2/shared/src/test/mlscript-compile/Shape.mls b/hkmc2/shared/src/test/mlscript-compile/Shape.mls index e69de29bb2..0b724cfcd7 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Shape.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Shape.mls @@ -0,0 +1,120 @@ +import "./Block.mls" +import "./Option.mls" + +open Block { Literal, Symbol, ClassSymbol, showSymbol } +open Option + +type Shape = Shape.Shape + +fun isPrimitiveType(sym: Symbol) = + if sym.name is + "Str" then true + "Int" then true + "Num" then true + "Bool" then true + else false + +fun isPrimitiveTypeOf(sym: Symbol, l: Literal) = + if [sym.name, l] is + ["Str", l] and l is Str then true + ["Int", i] and i is Int then true + ["Num", n] and n is Num then true + ["Bool", b] and b is Bool then true + else false + +module Shape with... + +class Shape with + constructor + Dyn() + Lit(val l: Literal) + Arr(val shapes: Array[Shape], val inf: Bool) + Class(val sym: ClassSymbol, val params: Array[Shape]) + +fun show(s: Shape) = + if s is + Dyn then "Dyn" + Lit(lit) then "Lit(" + Block.showLiteral(lit) + ")" + Arr(shapes, inf) then "Arr([" + shapes.map(show).join(", ") + "], " + inf + ")" + Class(sym, params) then "Class(" + showSymbol(sym) + ", [" + params.map(show).join(", ") + "])" + +fun zipMrg[A](a: Array[A], b: Array[A]): Array[A] = + a.map((a, i, _) => mrg2(a, b.at(i))) + +// TODO: remove, this is no longer in use +fun mrg2(s1: Shape, s2: Shape) = + if s1 == s2 then s1 + else if [s1, s2] is + [Lit(l), Class(sym, params)] + and isPrimitiveTypeOf(sym, l) + then Class(sym, params) + [Class(sym1, ps), Class(sym2, s2)] + and sym1.name == sym2.name + then Class(sym1, ps.map(p => [p.0, zipMrg(p.1, s2)])) + [Arr(s1, false), Arr(s2, false)] + and s1.length == s2.length + then Arr(zipMrg(s1, s2), false) + else Dyn() + +fun mrg(s1: Array[Shape]) = + s1.reduceRight((acc, s, _, _) => mrg2(s, acc)) + +fun sel(s1: Shape, s2: Shape): Array[Shape] = + if [s1, s2] is + [Class(sym, params), Lit(n)] and n is Str + and sym.args is Some(args) + and args.find(_ == n) + == () then [] + is i then [params.(i)] + [Dyn, Lit(n)] and n is Str + then [Dyn()] + [Arr(shapes, false), Lit(n)] and n is Int + then [shapes.(n)] + [Arr(shapes, false), Dyn] then + shapes + [Arr(shapes, true), _] then [Dyn()] // TODO + [Dyn, Lit(n)] and n is Int + then [Dyn()] + [Dyn, Dyn] + then [Dyn()] + else [] // TODO: return no possibility instead of err? + +fun static(s: Shape) = + if s is + Dyn then false + Lit(l) then not (l is Str and isPrimitiveType(l)) // redundant bracket? + Class(_, params) then params.every(static) + Arr(shapes, false) then shapes.every(static) + Arr(shapes, true) then false // TODO + +open Block { Case } + +fun silh(p: Case): Shape = if p is + Block.Lit(l) then Lit(l) + Block.Cls(sym, path) then + val size = if sym.args is Some(i) then i else 0 + Class(sym, Array(size).fill(Dyn)) + Block.Tup(n, inf) then Arr(Array(n).fill(Dyn), inf) + +// TODO: use Option instead, since all of them return at most one shape +fun filter(s: Shape, p: Case): Array[Shape] = + if [s, p] is + [Lit(l1), Block.Lit(l2)] and l1 == l2 then [s] + [Lit(l), Block.Cls(c, _)] and isPrimitiveTypeOf(c, l) then [s] + [Arr(ls, false), Block.Tup(n, false)] and ls.length == n then [s] + [Arr(ls, true), _] then [s] // TODO + [_, Block.Tup(ls, true)] then [s] // TODO + [Class(c1, _), Block.Cls(c2, _)] and c1.name == c2.name then [s] + [Dyn, _] then [silh(p)] + else [] + +fun rest(s: Shape, p: Case): Array[Shape] = + if [s, p] is + [Lit(l1), Block.Lit(l2)] and l1 == l2 then [] + [Lit(l), Block.Cls(c, _)] and isPrimitiveTypeOf(c, l) then [] + [Arr(ls, false), Block.Tup(n, false)] and ls.length == n then [] + [Arr(ls, true), _] then [s] // TODO + [_, Block.Tup(ls, true)] then [s] // TODO + [Class(c1, _), Block.Cls(c2, _)] and c1.name == c2.name then [] + [Dyn, _] then [s] + else [s] \ No newline at end of file diff --git a/hkmc2/shared/src/test/mlscript-compile/ShapeMap.mls b/hkmc2/shared/src/test/mlscript-compile/ShapeMap.mls new file mode 100644 index 0000000000..6b73d8f826 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/ShapeMap.mls @@ -0,0 +1,19 @@ +import "./Option.mls" +import "./ShapeSet.mls" + +open Option { Some, None } +open ShapeSet + +type ShapeMap = ShapeMap.ShapeMap + +module ShapeMap with... + +fun hash(s: ShapeSet) = s.keys().join(", ") + +class ShapeMap(val underlying: Map) with + fun add(s: ShapeSet, code) = underlying.set(hash(s), code) + + fun get(s: ShapeSet) = + if underlying.get(hash(s)) + == () then None + is value then Some of value diff --git a/hkmc2/shared/src/test/mlscript-compile/ShapeSet.mls b/hkmc2/shared/src/test/mlscript-compile/ShapeSet.mls new file mode 100644 index 0000000000..89b0554d89 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/ShapeSet.mls @@ -0,0 +1,89 @@ +import "./Predef.mls" +import "./Block.mls" +import "./Shape.mls" + +open Predef +open Block { ClassSymbol } +open Shape { show } + +type ShapeSet = ShapeSet.ShapeSet + +module ShapeSet with... + +// FIXME: hash is based on uniqueness of pretty printing +fun hash(s: Shape) = show(s) + +fun printShapeSet(s: ShapeSet) = console.log([... s.keys()].join(", ")) + +class ShapeSet(val underlying: Map) with + fun keys() = [...underlying.keys()].toSorted() + + fun values() = underlying.values().toArray() + + fun isEmpty() = underlying.size == 0 + + fun contains(s: Shape) = underlying.has(hash(s)) + + fun flatMap(f) = liftMany(values().flatMap(f)) + +fun create() = ShapeSet(new Map) + +fun lift(s: Shape) = ShapeSet(new Map([[hash(s), s]])) + +fun liftMany(arr: Array[Shape]) = ShapeSet(new Map(arr.map(s => [hash(s), s]))) + +// combining ShapeSet + +fun union(s1: ShapeSet, s2: ShapeSet) = ShapeSet(new Map([...s1.underlying, ...s2.underlying])) + +fun flat(arr: Array[ShapeSet]) = ShapeSet(new Map(arr.map(_.underlying.entries().toArray()).flat())) + +// Cartesian product: https://stackoverflow.com/a/43053803 +fun prod(xs) = + if xs.length == + 0 then [[]] + 1 then xs + else xs.reduce((a, b) => a.flatMap(d => b.map(e => [d, e].flat()))) + +open Shape { Dyn, Lit, Arr, Class } + +// lifted constructors + +fun mkBot() = create() + +fun mkDyn() = lift(Dyn()) + +fun mkLit(l) = lift(Lit(l)) + +fun mkArr(shapes: Array[ShapeSet], inf: Bool) = + shapes + .map(_.underlying.values().toArray()) + |> prod + .map(x => Arr(x, inf)) + |> liftMany + +fun mkClass(sym: ClassSymbol, params: Array[ShapeSet]) = + params.map(_.underlying.values().toArray()) + |> prod + .map(Class(sym, _)) + |> liftMany + +// helper functions + +fun filter(s: ShapeSet, p: Block.Case) = s.flatMap(Shape.filter(_, p)) + +fun rest(s: ShapeSet, p: Block.Case) = s.flatMap(Shape.rest(_, p)) + +fun sel(s1: ShapeSet, s2: ShapeSet) = + prod([s1.values(), s2.values()]) + .flatMap(pair => Shape.sel(pair.0, pair.1)) + |> liftMany + +fun mrg(s1: ShapeSet, s2: ShapeSet) = + mkDyn() // TODO + +open Block { Block } + +fun pruneBadArms(arms: Array[[ShapeSet, Block]]) = + val rem = arms.filter(arm => not (arm.0.isEmpty() and arm.1 is End)) + [flat(rem.map(_.0)), rem.map(_.1)] diff --git a/hkmc2/shared/src/test/mlscript/staging/Functions.mls b/hkmc2/shared/src/test/mlscript/staging/Functions.mls new file mode 100644 index 0000000000..509fa8423f --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/staging/Functions.mls @@ -0,0 +1,96 @@ + +:js +:staging +val x = [1, 2, 3] +staged module Expressions with + fun lit() = 1 + fun assign() = + let x = 42 + let y = x + y + fun tup1() = [1, 2] + fun tup2() = [1, ..x] + fun dynsel() = [1].(0) + fun match1() = + if 9 is + Bool then 1 + 8 then 2 + Int then 3 + 9 then 4 + else 0 + fun match2() = + if [...x] is + [] then 1 + // [1, 2] then 2 // TODO: needs handling for Label, Break + [a, ...] then 3 + else 0 +//│ > Return(Lit(1), false) +//│ > Assign(Symbol("x"), Lit(42), Assign(Symbol("y"), Ref(Symbol("x")), Return(Ref(Symbol("y")), false))) +//│ > Return(Tuple([Lit(1), Lit(2)]), false) +//│ > Return(Tuple([Lit(1), ..Ref(Symbol("x"))]), false) +//│ > Assign(Symbol("tmp"), Tuple([Lit(1)]), Return(DynSelect(Ref(Symbol("tmp")), Lit(0), false), false)) +//│ > Assign(Symbol("scrut"), Lit(9), Match(Ref(Symbol("scrut")), [Cls(ClassSymbol("Bool"), Select(Ref(Symbol("runtime")), Symbol("unreachable"))) -> Return(Lit(1), false), Lit(8) -> Return(Lit(2), false), Cls(ClassSymbol("Int"), Select(Ref(Symbol("runtime")), Symbol("unreachable"))) -> Return(Lit(3), false)], Return(Lit(0), false), End)) +//│ > Assign(Symbol("scrut"), Tuple([...Ref(Symbol("x"))]), Match(Ref(Symbol("scrut")), [Tup(0, false) -> Return(Lit(1), false), Tup(1, true) -> Assign(Symbol("element0$"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("get")), [Ref(Symbol("scrut")), Lit(0)]), Assign(Symbol("middleElements"), Call(Select(Select(Ref(Symbol("runtime")), Symbol("Tuple")), Symbol("slice")), [Ref(Symbol("scrut")), Lit(1), Lit(0)]), Assign(Symbol("a"), Ref(Symbol("element0$")), Return(Lit(3), false))))], Return(Lit(0), false), End)) +//│ x = [1, 2, 3] + +:js +:staging +:fixme +staged module OtherBlocks with + fun breakAndLabel() = + if 1 is + 2 then 0 + 3 then 0 + else 0 +//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing + +:js +:staging +class Outside(a) +staged module ClassInstrumentation with + class Inside(a, b) + class NoArg + fun inst1() = new Outside(1) + fun inst2() = new NoArg + fun app1() = Outside(1) + fun app2() = Inside(1, 2) +//│ > Define(ClsLikeDefn(ClassSymbol("NoArg"), TODO), End) +//│ > Define(ClsLikeDefn(ClassSymbol("Inside":[Symbol("a"), Symbol("b")]), TODO), End) +//│ > Return(Instantiate(Ref(ClassSymbol("Outside":[Symbol("a")])), [Lit(1)]), false) +//│ > Return(Instantiate(Select(Ref(Symbol("ClassInstrumentation")), Symbol("NoArg")), []), false) +//│ > Return(Call(Ref(ClassSymbol("Outside":[Symbol("a")])), [Lit(1)]), false) +//│ > Return(Call(Select(Ref(Symbol("ClassInstrumentation")), Symbol("Inside")), [Lit(1), Lit(2)]), false) + +:js +:staging +staged module Arguments with + fun f(x) = + x = 1 + x + fun g(x)(y, z)() = z +//│ > Assign(Symbol("x"), Lit(1), Return(Ref(Symbol("x")), false)) +//│ > Return(Lit(undefined), false) + +// debug printing fails, collision with class name? +:js +:staging +:fixme +class A() +staged module A with + fun f() = 1 +//│ ═══[RUNTIME ERROR] TypeError: A1.f_gen is not a function + +// debug printing fails, unable to reference the class when calling the instrumented function +:js +:staging +:fixme +module A with + staged module B with + fun f() = 1 +//│ ╔══[COMPILATION ERROR] No definition found in scope for member 'B' +//│ ╟── which references the symbol introduced here +//│ ║ l.88: staged module B with +//│ ║ ^^^^^^ +//│ ║ l.89: fun f() = 1 +//│ ╙── ^^^^^^^^^^^^^^^ +//│ ═══[RUNTIME ERROR] ReferenceError: B is not defined diff --git a/hkmc2/shared/src/test/mlscript/staging/PrintCode.mls b/hkmc2/shared/src/test/mlscript/staging/PrintCode.mls new file mode 100644 index 0000000000..aac4722e59 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/staging/PrintCode.mls @@ -0,0 +1,38 @@ +:staging +:js + +//│ Block = class Block { +//│ Symbol: fun Symbol { class: class Symbol }, +//│ ClassSymbol: fun ClassSymbol { class: class ClassSymbol }, +//│ Arg: fun Arg { class: class Arg }, +//│ Case: class Case, +//│ Lit: fun Lit { class: class Lit }, +//│ Cls: fun Cls { class: class Cls }, +//│ Tup: fun Tup { class: class Tup }, +//│ Result: class Result, +//│ Call: fun Call { class: class Call }, +//│ Instantiate: fun Instantiate { class: class Instantiate }, +//│ Tuple: fun Tuple { class: class Tuple }, +//│ Path: class Path, +//│ Select: fun Select { class: class Select }, +//│ DynSelect: fun DynSelect { class: class DynSelect }, +//│ ValueRef: fun ValueRef { class: class ValueRef }, +//│ ValueLit: fun ValueLit { class: class ValueLit }, +//│ Defn: class Defn, +//│ ValDefn: fun ValDefn { class: class ValDefn }, +//│ ClsLikeDefn: fun ClsLikeDefn { class: class ClsLikeDefn }, +//│ FunDefn: fun FunDefn { class: class FunDefn }, +//│ ClsLikeBody: fun ClsLikeBody { class: class ClsLikeBody }, +//│ Block: class Block, +//│ Match: fun Match { class: class Match }, +//│ Return: fun Return { class: class Return }, +//│ Assign: fun Assign { class: class Assign }, +//│ Define: fun Define { class: class Define }, +//│ End: fun End { class: class End } +//│ } +//│ ShapeSet = class ShapeSet { ShapeSet: fun ShapeSet { class: class ShapeSet } } + +Block.printCode(Block.FunDefn(Block.Symbol("f"), [[Block.Symbol("x")]], Block.Return(Block.ValueLit(1), false))) +//│ > [ Symbol { name: 'x' } ] +//│ > FunDefn(Symbol("f"), ([Symbol("x")]), Return(Lit(1), false), undefined) + diff --git a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls b/hkmc2/shared/src/test/mlscript/staging/Syntax.mls index 94accabd06..e5c4b86216 100644 --- a/hkmc2/shared/src/test/mlscript/staging/Syntax.mls +++ b/hkmc2/shared/src/test/mlscript/staging/Syntax.mls @@ -1,5 +1,7 @@ :pt +:js +// :lot staged module A //│ Parsed tree: //│ Modified: @@ -18,11 +20,6 @@ staged fun f() = 0 :js :slot -:staging -:w staged module A -//│ ╔══[WARNING] `staged` keyword doesn't do anything currently. -//│ ║ l.23: staged module A -//│ ╙── ^^^^^^^^ //│ Pretty Lowered: -//│ define staged class A in set block$res = undefined in end +//│ define staged class A in set block$res1 = undefined in end diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index a134549583..1cc525c074 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -33,7 +33,8 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val runtimeNme = baseScp.allocateName(Elaborator.State.runtimeSymbol) val termNme = baseScp.allocateName(Elaborator.State.termSymbol) val blockNme = baseScp.allocateName(Elaborator.State.blockSymbol) - val shapeNme = baseScp.allocateName(Elaborator.State.shapeSymbol) + val optionNme = baseScp.allocateName(Elaborator.State.optionSymbol) + val shapeSetNme = baseScp.allocateName(Elaborator.State.shapeSetSymbol) val definitionMetadataNme = baseScp.allocateName(Elaborator.State.definitionMetadataSymbol) val prettyPrintNme = baseScp.allocateName(Elaborator.State.prettyPrintSymbol) @@ -61,7 +62,8 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: if importQQ.isSet then importRuntimeModule(termNme, termFile) if stageCode.isSet then importRuntimeModule(blockNme, blockFile) - importRuntimeModule(shapeNme, shapeFile) + importRuntimeModule(optionNme, optionFile) + importRuntimeModule(shapeSetNme, shapeSetFile) h private var hostCreated = false diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index 306d0183c9..c0c2719116 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -21,7 +21,8 @@ abstract class MLsDiffMaker extends DiffMaker: val runtimeFile: io.Path = predefFile.up / "Runtime.mjs" // * Contains MLscript runtime definitions val termFile: io.Path = predefFile.up / "Term.mjs" // * Contains MLscript runtime term definitions val blockFile: io.Path = predefFile.up / "Block.mjs" // * Contains MLscript runtime block definitions - val shapeFile: io.Path = predefFile.up / "Shape.mjs" // * Contains MLscript runtime shape definitions + val optionFile: io.Path = predefFile.up / "Option.mjs" // * Contains MLscipt runtime option definition + val shapeSetFile: io.Path = predefFile.up / "ShapeSet.mjs" // * Contains MLscript runtime shapeset definitions val wd = file.up @@ -168,7 +169,7 @@ abstract class MLsDiffMaker extends DiffMaker: given Config = mkConfig processTrees( PrefixApp(Keywrd(`import`), StrLit(blockFile.toString)) - :: PrefixApp(Keywrd(`import`), StrLit(shapeFile.toString)) + :: PrefixApp(Keywrd(`import`), StrLit(shapeSetFile.toString)) :: Nil) super.init()