Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions kmir/src/kmir/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,13 @@ def _arg_parser() -> ArgumentParser:
action='store_true',
help='Break on every MIR step (statements and terminators)',
)
prove_args.add_argument(
'--break-on-function',
dest='break_on_function',
action='append',
default=None,
help='Break when calling functions / intrinsics matching this name (repeatable)',
)

proof_args = ArgumentParser(add_help=False)
proof_args.add_argument('id', metavar='PROOF_ID', help='The id of the proof to view')
Expand Down Expand Up @@ -638,6 +645,7 @@ def _parse_args(ns: Namespace) -> KMirOpts:
break_every_step=ns.break_every_step,
terminate_on_thunk=ns.terminate_on_thunk,
add_module=ns.add_module,
break_on_function=ns.break_on_function or [],
)
case 'link':
return LinkOpts(
Expand Down
7 changes: 7 additions & 0 deletions kmir/src/kmir/_prove.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _prove_rs(opts: ProveRSOpts, target_path: Path, label: str) -> APRProof:
symbolic=True,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
break_on_function=opts.break_on_function or None,
)
else:
_LOGGER.info(f'Constructing initial proof: {label}')
Expand Down Expand Up @@ -92,6 +93,7 @@ def _prove_rs(opts: ProveRSOpts, target_path: Path, label: str) -> APRProof:
symbolic=True,
haskell_target=opts.haskell_target,
llvm_lib_target=opts.llvm_lib_target,
break_on_function=opts.break_on_function or None,
)

proof = apr_proof_from_smir(
Expand Down Expand Up @@ -122,6 +124,7 @@ def _prove_rs(opts: ProveRSOpts, target_path: Path, label: str) -> APRProof:
break_on_terminator_unreachable=opts.break_on_terminator_unreachable,
break_every_terminator=opts.break_every_terminator,
break_every_step=opts.break_every_step,
break_on_function=opts.break_on_function,
)

if opts.max_workers and opts.max_workers > 1:
Expand Down Expand Up @@ -251,6 +254,7 @@ def _cut_point_rules(
break_on_terminator_unreachable: bool,
break_every_terminator: bool,
break_every_step: bool,
break_on_function: list[str] | None = None,
) -> list[str]:
cut_point_rules = []
if break_on_thunk:
Expand Down Expand Up @@ -291,6 +295,9 @@ def _cut_point_rules(
or break_every_step
):
cut_point_rules.append('KMIR-CONTROL-FLOW.termCallFunction')
if break_on_function:
cut_point_rules.append('KMIR-CONTROL-FLOW.termCallFunctionFilter')
cut_point_rules.append('KMIR-CONTROL-FLOW.termCallIntrinsicFilter')
if break_on_terminator_assert or break_every_terminator or break_every_step:
cut_point_rules.append('KMIR-CONTROL-FLOW.termAssert')
if break_on_terminator_drop or break_every_terminator or break_every_step:
Expand Down
1 change: 1 addition & 0 deletions kmir/src/kmir/kast.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _make_symbolic_call_config(
types: Mapping[Ty, TypeMetadata],
) -> tuple[KInner, list[KInner]]:
locals, constraints = _symbolic_locals(fn_data.args, types)

subst = Subst(
{
'K_CELL': fn_data.call_terminator,
Expand Down
72 changes: 72 additions & 0 deletions kmir/src/kmir/kdist/mir-semantics/kmir.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ See [`rt/configuration.md`](./rt/configuration.md) for a detailed description of
```k
module KMIR-CONTROL-FLOW
imports BOOL
imports COLLECTIONS
imports LIST
imports MAP
imports STRING
imports K-EQUAL
imports MONO
Expand Down Expand Up @@ -325,6 +327,15 @@ where the returned result should go.
=> #execIntrinsic(FUNC, ARGS, DEST, SPAN) ~> #continueAt(TARGET)
</k>
requires isIntrinsicFunction(FUNC)
andBool notBool #functionNameMatchesEnv(getFunctionName(FUNC))
// Intrinsic function call to a function in the break-on set - same as termCallIntrinsic but separate rule id for cut-point
rule [termCallIntrinsicFilter]:
<k> #execTerminatorCall(_, FUNC, ARGS, DEST, TARGET, _UNWIND, SPAN) ~> _
=> #execIntrinsic(FUNC, ARGS, DEST, SPAN) ~> #continueAt(TARGET)
</k>
requires isIntrinsicFunction(FUNC)
andBool #functionNameMatchesEnv(getFunctionName(FUNC))
// Regular function call - full state switching and stack setup
rule [termCallFunction]:
Expand All @@ -342,11 +353,72 @@ where the returned result should go.
</currentFrame>
<stack> STACK => ListItem(StackFrame(OLDCALLER, OLDDEST, OLDTARGET, OLDUNWIND, LOCALS)) STACK </stack>
requires notBool isIntrinsicFunction(FUNC)
andBool notBool #functionNameMatchesEnv(getFunctionName(FUNC))
// Function call to a function in the break-on set - same as termCallFunction but separate rule id for cut-point
rule [termCallFunctionFilter]:
<k> #execTerminatorCall(FTY, FUNC, ARGS, DEST, TARGET, UNWIND, SPAN) ~> _
=> #setUpCalleeData(FUNC, ARGS, SPAN)
</k>
<currentFunc> CALLER => FTY </currentFunc>
<currentFrame>
<currentBody> _ </currentBody>
<caller> OLDCALLER => CALLER </caller>
<dest> OLDDEST => DEST </dest>
<target> OLDTARGET => TARGET </target>
<unwind> OLDUNWIND => UNWIND </unwind>
<locals> LOCALS </locals>
</currentFrame>
<stack> STACK => ListItem(StackFrame(OLDCALLER, OLDDEST, OLDTARGET, OLDUNWIND, LOCALS)) STACK </stack>
requires notBool isIntrinsicFunction(FUNC)
andBool #functionNameMatchesEnv(getFunctionName(FUNC))
syntax Bool ::= isIntrinsicFunction(MonoItemKind) [function]
rule isIntrinsicFunction(IntrinsicFunction(_)) => true
rule isIntrinsicFunction(_) => false [owise]
syntax String ::= getFunctionName(MonoItemKind) [function, total]
//---------------------------------------------------------------
rule getFunctionName(monoItemFn(symbol(NAME), _, _)) => NAME
rule getFunctionName(monoItemStatic(symbol(NAME), _, _)) => NAME
rule getFunctionName(monoItemGlobalAsm(_)) => ""
rule getFunctionName(IntrinsicFunction(symbol(NAME))) => NAME
// Check whether a function name matches any filter in the break-on-functions list.
syntax Bool ::= #functionNameMatchesEnv(String) [function, total]
//----------------------------------------------------------------
rule #functionNameMatchesEnv(NAME) => #functionNameMatchesEnvStr(NAME, #breakOnFunctionsString(0))
// The Int argument is unused; it exists only so the Haskell backend can
// pattern-match on it and not error since zero-argument functions cannot use [owise].
syntax String ::= #breakOnFunctionsString(Int) [function, total, symbol(breakOnFunctionsString)]
//-----------------------------------------------------------------------------------------------
rule #breakOnFunctionsString(_) => "" [owise] // This gets overridden by corresponding python function
syntax Bool ::= #functionNameMatchesEnvStr(String, String) [function, total]
//--------------------------------------------------------------------------
rule #functionNameMatchesEnvStr(_, "") => false
rule #functionNameMatchesEnvStr(NAME, ENV) => #functionNameMatchesAnyList(NAME, #splitSemicolon(ENV))
requires ENV =/=String ""
syntax List ::= #splitSemicolon(String) [function, total]
//--------------------------------------------------------
rule #splitSemicolon(S) => #splitSemicolonAux(S, findString(S, ";", 0))
syntax List ::= #splitSemicolonAux(String, Int) [function, total]
//-----------------------------------------------------------------
rule #splitSemicolonAux(S, -1) => ListItem(S)
rule #splitSemicolonAux(S, I) =>
ListItem(substrString(S, 0, I)) #splitSemicolon(substrString(S, I +Int 1, lengthString(S)))
requires I >=Int 0
syntax Bool ::= #functionNameMatchesAnyList(String, List) [function, total]
//-------------------------------------------------------------------------
rule #functionNameMatchesAnyList(_, .List) => false
rule #functionNameMatchesAnyList(NAME, ListItem(FILTER:String) REST) =>
0 <=Int findString(NAME, FILTER, 0) orBool #functionNameMatchesAnyList(NAME, REST)
rule #functionNameMatchesAnyList(_, _) => false [owise]
syntax KItem ::= #continueAt(MaybeBasicBlockIdx)
rule <k> #continueAt(someBasicBlockIdx(TARGET)) => #execBlockIdx(TARGET) ... </k>
rule <k> #continueAt(noBasicBlockIdx) => .K ... </k>
Expand Down
2 changes: 2 additions & 0 deletions kmir/src/kmir/kmir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def from_kompiled_kore(
llvm_target: str | None = None,
llvm_lib_target: str | None = None,
haskell_target: str | None = None,
break_on_function: list[str] | None = None,
) -> KMIR:
from .kompile import kompile_smir

Expand All @@ -75,6 +76,7 @@ def from_kompiled_kore(
llvm_target=llvm_target,
llvm_lib_target=llvm_lib_target,
haskell_target=haskell_target,
break_on_function=break_on_function,
)
return kompiled_smir.create_kmir(bug_report_file=bug_report)

Expand Down
42 changes: 39 additions & 3 deletions kmir/src/kmir/kompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class KompileDigest:
llvm_target: str
llvm_lib_target: str
haskell_target: str
break_on_function: str

@staticmethod
def load(target_dir: Path) -> KompileDigest:
Expand All @@ -80,6 +81,7 @@ def load(target_dir: Path) -> KompileDigest:
llvm_target=data['llvm-target'],
llvm_lib_target=data['llvm-lib-target'],
haskell_target=data['haskell-target'],
break_on_function=data.get('break-on-function', ''),
)

def write(self, target_dir: Path) -> None:
Expand All @@ -91,6 +93,7 @@ def write(self, target_dir: Path) -> None:
'llvm-target': self.llvm_target,
'llvm-lib-target': self.llvm_lib_target,
'haskell-target': self.haskell_target,
'break-on-function': self.break_on_function,
},
),
)
Expand Down Expand Up @@ -205,6 +208,7 @@ def kompile_smir(
llvm_target: str | None = None,
llvm_lib_target: str | None = None,
haskell_target: str | None = None,
break_on_function: list[str] | None = None,
) -> KompiledSMIR:
kompile_digest: KompileDigest | None = None
try:
Expand All @@ -222,6 +226,7 @@ def kompile_smir(
llvm_target=llvm_target,
llvm_lib_target=llvm_lib_target,
haskell_target=haskell_target,
break_on_function=';'.join(break_on_function) if break_on_function else '',
)

target_hs_path = target_dir / 'haskell'
Expand All @@ -242,7 +247,7 @@ def kompile_smir(

haskell_def_dir = kdist.which(haskell_target)
kmir = KMIR(haskell_def_dir)
smir_rules: list[Sentence] = list(make_kore_rules(kmir, smir_info))
smir_rules: list[Sentence] = list(make_kore_rules(kmir, smir_info, break_on_function=break_on_function))
_LOGGER.info(f'Generated {len(smir_rules)} function equations to add to `definition.kore')

# Load and convert extra module rules if provided
Expand Down Expand Up @@ -437,7 +442,9 @@ def _make_stratified_rules(
return [*declarations, *dispatch, *defaults, *equations]


def make_kore_rules(kmir: KMIR, smir_info: SMIRInfo) -> Sequence[Sentence]:
def make_kore_rules(
kmir: KMIR, smir_info: SMIRInfo, *, break_on_function: list[str] | None = None
) -> Sequence[Sentence]:
# kprint tool is too chatty
kprint_logger = logging.getLogger('pyk.ktool.kprint')
kprint_logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -489,7 +496,12 @@ def get_int_arg(app: KInner) -> int:
kmir, 'lookupAlloc', 'AllocId', 'Evaluation', 'allocId', allocs, invalid_alloc_n
)

return [*equations, *type_equations, *alloc_equations]
# Generate break-on-function filter rule if filters are provided
break_on_rules: list[Axiom] = []
if break_on_function:
break_on_rules.append(_mk_break_on_functions_rule(kmir, break_on_function))

return [*equations, *type_equations, *alloc_equations, *break_on_rules]


def _functions(kmir: KMIR, smir_info: SMIRInfo) -> dict[int, KInner]:
Expand Down Expand Up @@ -544,6 +556,30 @@ def _mk_equation(kmir: KMIR, fun: str, arg: KInner, arg_sort: str, result: KInne
return rule.to_axiom()


def _mk_break_on_functions_rule(kmir: KMIR, break_on_function: list[str]) -> Axiom:
"""Generate Kore rule for filtering function breaks: `#breakOnFunctionsString(0) => "filter1;filter2;..."`"""
from pyk.kore.prelude import int_dv
from pyk.kore.rule import FunctionRule

filter_string = ';'.join(break_on_function)
fun_app = App('LblbreakOnFunctionsString', (), (int_dv(0),))
result_kore = kmir.kast_to_kore(stringToken(filter_string), KSort('String'))

rule = FunctionRule(
lhs=fun_app,
rhs=result_kore,
req=None,
ens=None,
sort=SortApp('SortString'),
arg_sorts=(SortApp('SortInt'),),
anti_left=None,
priority=50,
uid='breakOnFunctionsString-generated',
label='breakOnFunctionsString-generated',
)
return rule.to_axiom()


def _decode_alloc(smir_info: SMIRInfo, raw_alloc: Any) -> tuple[KInner, KInner]:
from .decoding import UnableToDecodeValue, decode_alloc_or_unable

Expand Down
5 changes: 5 additions & 0 deletions kmir/src/kmir/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ProveOpts(KMirOpts):
break_every_terminator: bool
break_every_step: bool
terminate_on_thunk: bool
break_on_function: list[str]

def __init__(
self,
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
break_every_terminator: bool = False,
break_every_step: bool = False,
terminate_on_thunk: bool = False,
break_on_function: list[str] | None = None,
) -> None:
self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None
self.haskell_target = haskell_target
Expand All @@ -138,6 +140,7 @@ def __init__(
self.break_every_terminator = break_every_terminator
self.break_every_step = break_every_step
self.terminate_on_thunk = terminate_on_thunk
self.break_on_function = break_on_function if break_on_function is not None else []


@dataclass
Expand Down Expand Up @@ -182,6 +185,7 @@ def __init__(
break_every_step: bool = False,
terminate_on_thunk: bool = False,
add_module: Path | None = None,
break_on_function: list[str] | None = None,
) -> None:
self.rs_file = rs_file
self.proof_dir = Path(proof_dir).resolve() if proof_dir is not None else None
Expand Down Expand Up @@ -213,6 +217,7 @@ def __init__(
self.break_every_step = break_every_step
self.terminate_on_thunk = terminate_on_thunk
self.add_module = add_module
self.break_on_function = break_on_function if break_on_function is not None else []


@dataclass
Expand Down
15 changes: 15 additions & 0 deletions kmir/src/tests/integration/data/prove-rs/break-on-function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#![feature(core_intrinsics)]

fn foo() {
let x = std::hint::black_box(42);
bar();
assert!(x == 42);
}

fn bar() {
std::intrinsics::assert_inhabited::<i32>();
}

fn main() {
foo();
}
Loading