diff --git a/docs/changelog.md b/docs/changelog.md index 11155184d..23b7dfc41 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,21 @@ small.label { } +## 4.22.0 (Jan 3, 2026) { id="4.22.0" } + +Features & improvements: + +- added support for Solidity 0.8.31, 0.8.32 and 0.8.33 [core] +- improved warning compilation messages when files cannot be compiled together due to version constraints [core] +- improved warning compilation messages when a file from a subproject is included into a different subproject [core] +- limited the maximum number of running solc instances in parallel to CPU count [core] +- documented how to use contract-defined invariants in fuzz tests [documentation] + +Fixes: + +- `incremental` compilation setting is now respected on compilation JSON import & export [core] +- fixed `read_storage_variable` and `write_storage_variable` when working with contract and enum types [testing framework] + ## 4.21.0 (Nov 17, 2025) { id="4.21.0" } Features & improvements: diff --git a/docs/testing-framework/fuzzing.md b/docs/testing-framework/fuzzing.md index 55d35fd42..64dd3d1cf 100644 --- a/docs/testing-framework/fuzzing.md +++ b/docs/testing-framework/fuzzing.md @@ -69,6 +69,38 @@ def invariant_count(self) -> None: An optional `period` argument can be passed to the `@invariant` decorator. If specified, the invariant is executed only after every `period` flows. +#### Contract-defined invariants (Echidna / Medusa / Foundry style) + +In addition to Python-defined invariants, invariants implemented directly in contracts (in the style of Echidna / Medusa / Foundry) can be reused by calling those functions from Python. + +For example, the following Solidity contract can be used: + +```solidity +contract Counter { + uint256 public totalSupply; + bool public overflowed; + + // Reverts on violation + function echidna_total_supply_invariant() public view { + require(totalSupply <= 1_000_000, "totalSupply too high"); + } + + // Returns false on violation + function invariant_no_overflow() public view returns (bool) { + return !overflowed; + } +} +``` + +These contract-level invariants can be wired into a Wake fuzz test as follows: + +```python +@invariant() +def invariant_contract_invariants(self) -> None: + self.counter.echidna_total_supply_invariant() # reverts on violation + assert self.counter.invariant_no_overflow() # returns bool +``` + ### Execution hooks Execution hooks are functions that are executed during the `FuzzTest` lifecycle. This is the list of all available execution hooks: diff --git a/wake/cli/__main__.py b/wake/cli/__main__.py index 235a5be0b..5ad6f6a0b 100644 --- a/wake/cli/__main__.py +++ b/wake/cli/__main__.py @@ -21,6 +21,7 @@ from .run import run_run from .svm import run_svm from .test import run_test +from .mutate import run_mutate if platform.system() != "Windows": try: @@ -232,7 +233,7 @@ def exit(): main.add_command(run_run) main.add_command(run_svm) main.add_command(run_test) - +main.add_command(run_mutate) @main.command(name="config") @click.pass_context diff --git a/wake/cli/compile.py b/wake/cli/compile.py index 655e23d25..d171de860 100644 --- a/wake/cli/compile.py +++ b/wake/cli/compile.py @@ -32,6 +32,7 @@ def export_json( out = { "version": build_info.wake_version, + "incremental": build_info.incremental, "system": platform.system(), "project_root": str(config.project_root_path), "wake_contracts_path": str(config.wake_contracts_path), @@ -93,6 +94,9 @@ async def compile( wake_contracts_path = PurePosixPath(loaded["wake_contracts_path"]) original_project_root = PurePosixPath(loaded["project_root"]) + if incremental is None and loaded.get("incremental", None) is not None: + incremental = loaded["incremental"] + config = WakeConfig.fromdict( loaded["config"], wake_contracts_path=wake_contracts_path, diff --git a/wake/cli/detect.py b/wake/cli/detect.py index cfa8d2375..efc4bb3d1 100644 --- a/wake/cli/detect.py +++ b/wake/cli/detect.py @@ -707,6 +707,7 @@ def process_detection(detection: Detection) -> Dict[str, Any]: sys.exit(0 if len(all_detections) == 0 else 3) if import_json is None: + incremental = None scan_extra = {} sol_files: Set[Path] = set() modified_files: Dict[Path, bytes] = {} @@ -746,6 +747,8 @@ def process_detection(detection: Detection) -> Dict[str, Any]: wake_contracts_path = PurePosixPath(loaded["wake_contracts_path"]) original_project_root = PurePosixPath(loaded["project_root"]) + incremental = loaded.get("incremental", None) + config = WakeConfig.fromdict( loaded["config"], wake_contracts_path=wake_contracts_path, @@ -820,6 +823,7 @@ def process_detection(detection: Detection) -> Dict[str, Any]: console=console, no_warnings=True, modified_files=modified_files, + incremental=incremental, ) assert compiler.latest_build_info is not None diff --git a/wake/cli/mutate.py b/wake/cli/mutate.py new file mode 100644 index 000000000..c74874a9a --- /dev/null +++ b/wake/cli/mutate.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +import asyncio +import subprocess +import sys +import time +from enum import Enum, auto +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Tuple + +import rich_click as click +from rich.console import Console +from rich.table import Table +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn + +if TYPE_CHECKING: + from wake.config import WakeConfig + from wake.mutators.api import Mutation, Mutator + + +class TestResult(Enum): + """Result of running tests on a mutation.""" + PASSED = auto() + FAILED = auto() + COMPILE_ERROR = auto() + TIMEOUT = auto() + + +def split_csv(ctx, param, value) -> List[str]: + """Callback to split space/comma-separated values and flatten.""" + if not value: + return [] + result = [] + for item in value: + # Split on both comma and space + for v in item.replace(",", " ").split(): + v = v.strip() + if v: + result.append(v) + return result + + +def split_csv_paths(ctx, param, value) -> List[Path]: + """Callback to split space/comma-separated paths and flatten.""" + if not value: + return [] + result = [] + for item in value: + # Split on both comma and space + for v in item.replace(",", " ").split(): + v = v.strip() + if v: + p = Path(v) + if not p.exists(): + raise click.BadParameter(f"Path does not exist: {v}") + result.append(p) + return result + + +def discover_mutators() -> dict[str, type["Mutator"]]: + """Discover all mutator classes from wake_mutators package.""" + import importlib + import pkgutil + + import wake_mutators as mutators_pkg + from wake.mutators.api import Mutator + + found = {} + + for importer, modname, ispkg in pkgutil.iter_modules(mutators_pkg.__path__): + if modname.startswith("_"): + continue + + module = importlib.import_module(f"wake_mutators.{modname}") + + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + isinstance(attr, type) + and issubclass(attr, Mutator) + and attr is not Mutator + and attr.__module__ == module.__name__ + ): + found[attr.name] = attr + + return found + + +async def compile_project(config: "WakeConfig", contract_paths: List[Path]): + """Compile the project and return the build.""" + from wake.compiler.compiler import SolidityCompiler + from wake.compiler.solc_frontend import SolcOutputSelectionEnum + + compiler = SolidityCompiler(config) + compiler.load() + + build, errors = await compiler.compile( + contract_paths, + [SolcOutputSelectionEnum.AST], + write_artifacts=False, + ) + + return build + + +def collect_mutations( + config: "WakeConfig", + contract_paths: List[Path], + mutator_classes: List[type["Mutator"]], + console: Console, +) -> List["Mutation"]: + """Run mutators on contracts to collect all mutations.""" + from wake.core.visitor import visit_map, group_map + + start = time.perf_counter() + with console.status("[bold green]Compiling contracts...[/]"): + build = asyncio.run(compile_project(config, contract_paths)) + end = time.perf_counter() + console.log(f"[green]Compiled in [bold green]{end - start:.2f} s[/bold green][/]") + + all_mutations = [] + + start = time.perf_counter() + with console.status("[bold green]Collecting mutations...[/]"): + for mutator_cls in mutator_classes: + mutator = mutator_cls() + + for path in contract_paths: + source_unit = build.source_units.get(path) + if source_unit is None: + continue + + mutator._current_file = path + + for node in source_unit: + if node.ast_node.node_type in group_map: + for group in group_map[node.ast_node.node_type]: + if group in visit_map: + visit_map[group](mutator, node) + + if node.ast_node.node_type in visit_map: + visit_map[node.ast_node.node_type](mutator, node) + + all_mutations.extend(mutator.mutations) + + end = time.perf_counter() + console.log(f"[green]Found [bold green]{len(all_mutations)}[/bold green] mutation(s) in [bold green]{end - start:.2f} s[/bold green][/]") + + return all_mutations + + +def run_tests(test_paths: List[str], timeout: int = 120) -> TestResult: + """Regenerate pytypes and run wake test. Return TestResult.""" + try: + compile_result = subprocess.run( + ["wake", "up", "pytypes"], + capture_output=True, + timeout=timeout, + cwd=Path.cwd(), + ) + + if compile_result.returncode != 0: + return TestResult.COMPILE_ERROR + except subprocess.TimeoutExpired: + return TestResult.TIMEOUT + + cmd = ["wake", "test", "-x"] + test_paths + + try: + result = subprocess.run( + cmd, + capture_output=True, + timeout=timeout, + cwd=Path.cwd(), + ) + return TestResult.PASSED if result.returncode == 0 else TestResult.FAILED + except subprocess.TimeoutExpired: + return TestResult.TIMEOUT + + +@click.command(name="mutate") +@click.option( + "--contracts", + "-c", + multiple=True, + type=str, + callback=split_csv_paths, + help="Contract files to mutate (space or comma-separated).", +) +@click.option( + "--mutations", + "-m", + multiple=True, + type=str, + callback=split_csv, + help="Mutation operators to use (space or comma-separated). Default: all.", +) +@click.option( + "--list-mutations", + is_flag=True, + default=False, + help="List available mutation operators.", +) +@click.option( + "--timeout", + "-t", + type=int, + default=60, + help="Timeout for each test run in seconds.", +) +@click.option( + "-v", + "--verbosity", + default=0, + count=True, + help="Increase verbosity.", +) +@click.argument("test_paths", nargs=-1, type=click.Path(exists=True)) +@click.pass_context +def run_mutate( + context: click.Context, + contracts: List[Path], + mutations: List[str], + list_mutations: bool, + timeout: int, + verbosity: int, + test_paths: Tuple[str, ...], +) -> None: + """ Run mutation testing on Solidity contracts. """ + from wake.config import WakeConfig + from wake.mutators.api import MutantStatus + + console = Console() + + # Discover available mutators + available_mutators = discover_mutators() + + if list_mutations: + table = Table(title="Available Mutation Operators") + table.add_column("Name", style="cyan") + table.add_column("Description") + + for name, cls in sorted(available_mutators.items()): + table.add_row(name, cls.description) + + console.print(table) + return + + # Now contracts is required (but not for --list-mutations) + if not contracts: + raise click.BadParameter("--contracts/-c is required when running mutation tests.") + + # Default to all tests if none specified + if test_paths: + test_paths_list = list(test_paths) + else: + tests_dir = Path.cwd() / "tests" + if tests_dir.exists(): + test_paths_list = [str(p) for p in tests_dir.glob("test_*.py")] + else: + test_paths_list = [] + + if not test_paths_list: + raise click.BadParameter("No test files found. Specify test paths or create tests/test_*.py files.") + + console.log(f"[green]Auto-discovered [bold green]{len(test_paths_list)}[/bold green] test file(s)[/]") + + # Select mutators + if mutations: + selected = [] + for m in mutations: + if m not in available_mutators: + raise click.BadParameter(f"Unknown mutation operator: {m}") + selected.append(available_mutators[m]) + else: + selected = list(available_mutators.values()) + + # Load config + config = WakeConfig(local_config_path=context.obj.get("local_config_path", None)) + config.load_configs() + + # Resolve contract paths + resolved_contracts = [p.resolve() for p in contracts] + + # Print configuration + console.print() + console.rule("[bold]Mutation Testing Configuration[/bold]") + console.print(f" [dim]Contracts:[/dim] {', '.join(str(c.name) for c in contracts)}") + console.print(f" [dim]Mutations:[/dim] {', '.join(m.name for m in selected)}") + console.print(f" [dim]Tests:[/dim] {', '.join(test_paths_list)}") + console.print(f" [dim]Timeout:[/dim] {timeout}s") + console.print() + + # Collect all mutations + all_mutations = collect_mutations(config, resolved_contracts, selected, console) + + if not all_mutations: + console.print("[yellow]No mutations found.[/yellow]") + return + + # Run baseline test first + console.print() + with console.status("[bold green]Running baseline tests...[/]"): + start = time.perf_counter() + baseline_result = run_tests(test_paths_list, timeout) + end = time.perf_counter() + + if baseline_result != TestResult.PASSED: + console.log(f"[bold red]Baseline tests failed![/bold red] Fix tests before mutation testing.") + sys.exit(1) + + console.log(f"[green]Baseline tests passed in [bold green]{end - start:.2f} s[/bold green][/]") + console.print() + + # Test each mutation + results = [] + + console.rule("[bold]Running Mutations[/bold]") + console.print() + + total_start = time.perf_counter() + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + transient=True, + ) as progress: + task = progress.add_task("[cyan]Testing mutations...", total=len(all_mutations)) + + for i, mutation in enumerate(all_mutations, 1): + progress.update(task, description=f"[cyan]Testing mutation {i}/{len(all_mutations)}") + + original_source = mutation.file_path.read_bytes() + + try: + mutated_source = mutation.apply(original_source) + mutation.file_path.write_bytes(mutated_source) + + test_result = run_tests(test_paths_list, timeout) + + if test_result == TestResult.PASSED: + status = MutantStatus.SURVIVED + elif test_result == TestResult.FAILED: + status = MutantStatus.KILLED + elif test_result == TestResult.COMPILE_ERROR: + status = MutantStatus.COMPILE_ERROR + elif test_result == TestResult.TIMEOUT: + status = MutantStatus.TIMEOUT + + results.append((mutation, status)) + + except Exception as e: + results.append((mutation, MutantStatus.COMPILE_ERROR)) + + finally: + mutation.file_path.write_bytes(original_source) + + progress.advance(task) + + total_end = time.perf_counter() + + # Print detailed results + if verbosity > 0: + console.print() + results_table = Table(title="Mutation Results", show_lines=True) + results_table.add_column("#", style="dim", width=4) + results_table.add_column("File", style="cyan") + results_table.add_column("Line", style="dim", width=6) + results_table.add_column("Mutation", style="white") + results_table.add_column("Status", justify="center") + + for i, (mutation, status) in enumerate(results, 1): + if status == MutantStatus.KILLED: + status_str = "[green]KILLED ✓[/green]" + elif status == MutantStatus.SURVIVED: + status_str = "[red]SURVIVED ✗[/red]" + elif status == MutantStatus.COMPILE_ERROR: + status_str = "[yellow]COMPILE ERROR ⚠[/yellow]" + elif status == MutantStatus.TIMEOUT: + status_str = "[yellow]TIMEOUT ⏱[/yellow]" + + results_table.add_row( + str(i), + mutation.file_path.name, + str(mutation.line_number), + mutation.description, + status_str, + ) + + console.print(results_table) + + # Summary + killed = sum(1 for _, s in results if s == MutantStatus.KILLED) + survived = sum(1 for _, s in results if s == MutantStatus.SURVIVED) + compile_errors = sum(1 for _, s in results if s == MutantStatus.COMPILE_ERROR) + timeouts = sum(1 for _, s in results if s == MutantStatus.TIMEOUT) + + console.print() + console.rule("[bold]Mutation Testing Summary[/bold]") + console.print() + + summary_table = Table(show_header=False, box=None, padding=(0, 2)) + summary_table.add_column("Label", style="dim") + summary_table.add_column("Value", justify="right") + + summary_table.add_row("Total mutations", str(len(results))) + summary_table.add_row("Killed", f"[green]{killed}[/green]") + summary_table.add_row("Survived", f"[red]{survived}[/red]") + if compile_errors: + summary_table.add_row("Compile Errors", f"[yellow]{compile_errors}[/yellow]") + if timeouts: + summary_table.add_row("Timeouts", f"[yellow]{timeouts}[/yellow]") + summary_table.add_row("Time elapsed", f"{total_end - total_start:.2f} s") + + console.print(summary_table) + + if killed + survived > 0: + score = killed / (killed + survived) * 100 + console.print() + if score >= 80: + console.print(f"[bold green]Mutation Score: {score:.1f}%[/bold green]") + elif score >= 50: + console.print(f"[bold yellow]Mutation Score: {score:.1f}%[/bold yellow]") + else: + console.print(f"[bold red]Mutation Score: {score:.1f}%[/bold red]") + + # List survivors + survivors = [(m, s) for m, s in results if s == MutantStatus.SURVIVED] + if survivors: + console.print() + console.rule("[bold red]Surviving Mutations[/bold red]") + console.print() + console.print("[dim]These mutations were not caught by tests - consider improving test coverage:[/dim]") + console.print() + + survivor_table = Table(show_header=True, box=None) + survivor_table.add_column("File", style="cyan") + survivor_table.add_column("Line", style="dim") + survivor_table.add_column("Mutation") + + for mutation, _ in survivors: + survivor_table.add_row( + mutation.file_path.name, + str(mutation.line_number), + mutation.description, + ) + + console.print(survivor_table) + + # Final status + console.print() + if survived == 0: + console.print("[bold green]All mutations were killed! ✓[/bold green]") + else: + console.print(f"[bold red]{survived} mutation(s) survived - tests need improvement[/bold red]") + + sys.exit(0 if survived == 0 else 1) \ No newline at end of file diff --git a/wake/compiler/compiler.py b/wake/compiler/compiler.py index b69730b00..627df2828 100644 --- a/wake/compiler/compiler.py +++ b/wake/compiler/compiler.py @@ -316,6 +316,7 @@ class SolidityCompiler: __solc_frontend: SolcFrontend __source_unit_name_resolver: SourceUnitNameResolver __source_path_resolver: SourcePathResolver + __solc_semaphore: asyncio.Semaphore _latest_build_info: Optional[ProjectBuildInfo] _latest_build: Optional[ProjectBuild] @@ -330,6 +331,7 @@ def __init__(self, wake_config: WakeConfig): self.__solc_frontend = SolcFrontend(wake_config) self.__source_unit_name_resolver = SourceUnitNameResolver(wake_config) self.__source_path_resolver = SourcePathResolver(wake_config) + self.__solc_semaphore = asyncio.Semaphore(os.cpu_count() or 4) self._latest_build_info = None self._latest_build = None @@ -508,17 +510,16 @@ def build_compilation_units_maximize( Builds a list of compilation units from a graph. Number of compilation units is maximized. """ - def __build_compilation_unit( - graph: nx.DiGraph, start: Iterable[str], subproject: Optional[str] - ) -> CompilationUnit: + def __build_compilation_unit(graph: nx.DiGraph, start: str) -> CompilationUnit: + subproject = graph.nodes[start]["subproject"] nodes_subset = set() - nodes_queue: deque[str] = deque(start) + nodes_queue: deque[tuple[str, Optional[str]]] = deque([(start, None)]) versions: SolidityVersionRanges = SolidityVersionRanges( [SolidityVersionRange(None, None, None, None)] ) while len(nodes_queue) > 0: - node = nodes_queue.pop() + node, importing_node = nodes_queue.pop() if node in nodes_subset: continue @@ -527,18 +528,24 @@ def __build_compilation_unit( versions &= graph.nodes[node]["versions"] compiled_with[node].add(subproject) + subproject_mismatch = False if graph.nodes[node]["subproject"] not in { subproject, None, } and not node.startswith("wake/"): logger.warning( - f"Including file {node} belonging to subproject '{graph.nodes[node]['subproject'] or ''}' into compilation of subproject '{subproject or ''}'" + f"Including file '{node}' belonging to subproject '{graph.nodes[node]['subproject'] or ''}' into compilation of subproject '{subproject or ''}' due to import from '{importing_node}'" ) + subproject_mismatch = True for in_edge in graph.in_edges(node): - _from, to = in_edge + _from, _ = in_edge if _from not in nodes_subset: - nodes_queue.append(_from) + # propagate original importing node if there is a subproject mismatch, otherwise propagate the current node + # this ensures that the original node causing whole subgraph subproject mismatch is reported + nodes_queue.append( + (_from, importing_node if subproject_mismatch else node) + ) subgraph = graph.subgraph(nodes_subset).copy() return CompilationUnit(subgraph, versions, subproject) @@ -553,9 +560,7 @@ def __build_compilation_unit( for sink in sinks: subproject = graph.nodes[sink]["subproject"] if subproject not in compiled_with[sink]: - compilation_unit = __build_compilation_unit( - graph, [sink], subproject - ) + compilation_unit = __build_compilation_unit(graph, sink) compilation_units.append(compilation_unit) graph.remove_node(sink) @@ -582,16 +587,13 @@ def __build_compilation_unit( break if is_closed_cycle: - subprojects = { - graph.nodes[node]["subproject"] for node in simple_cycle + # choose one representative node for each subproject + nodes_by_subproject = { + graph.nodes[node]["subproject"]: node for node in simple_cycle } - for subproject in subprojects: - if any( - subproject not in compiled_with[n] for n in simple_cycle - ): - compilation_unit = __build_compilation_unit( - graph, simple_cycle, subproject - ) + for subproject, node in nodes_by_subproject.items(): + if subproject not in compiled_with[node]: + compilation_unit = __build_compilation_unit(graph, node) compilation_units.append(compilation_unit) generated_cycles.add(frozenset(simple_cycle)) @@ -691,6 +693,7 @@ def optimize_build_settings( def determine_solc_versions( self, + graph: nx.DiGraph, compilation_units: Iterable[CompilationUnit], target_versions_by_subproject: Mapping[ Optional[str], Optional[SolidityVersion] @@ -704,6 +707,10 @@ def determine_solc_versions( target_version = target_versions_by_subproject.get( compilation_unit.subproject ) + files_str = "\n".join( + f" {su}: {graph.nodes[su]['versions']}" + for su in compilation_unit.source_unit_names + ) if all( is_relative_to(f, self.__config.wake_contracts_path) for f in compilation_unit.files @@ -712,7 +719,6 @@ def determine_solc_versions( if target_version is not None: if target_version not in compilation_unit.versions: - files_str = "\n".join(str(path) for path in compilation_unit.files) logger.warning( f"Unable to compile following files with solc version `{target_version}` set in config files:\n" + files_str @@ -727,7 +733,6 @@ def determine_solc_versions( if version in compilation_unit.versions ] if len(matching_versions) == 0: - files_str = "\n".join(str(path) for path in compilation_unit.files) logger.warning( f"Unable to compile following files with any solc version:\n" + files_str @@ -741,7 +746,6 @@ def determine_solc_versions( if version <= max_version ) except StopIteration: - files_str = "\n".join(str(path) for path in compilation_unit.files) logger.warning( f"The maximum supported version of Solidity is {max_version}, unable to compile the following files:\n" + files_str @@ -749,7 +753,6 @@ def determine_solc_versions( skipped_compilation_units.append(compilation_unit) continue if target_version < min_version: - files_str = "\n".join(str(path) for path in compilation_unit.files) logger.warning( f"The minimum supported version of Solidity is {min_version}, unable to compile the following files:\n" + files_str @@ -1142,7 +1145,7 @@ async def compile( ) | set(files_to_compile) target_versions, skipped_compilation_units = self.determine_solc_versions( - compilation_units, target_versions_by_subproject + graph, compilation_units, target_versions_by_subproject ) await self._install_solc(target_versions, console) @@ -1655,6 +1658,7 @@ async def compile_unit_raw( return SolcOutput() # run the solc executable - return await self.__solc_frontend.compile( - files, sources, target_version, build_settings - ) + async with self.__solc_semaphore: + return await self.__solc_frontend.compile( + files, sources, target_version, build_settings + ) diff --git a/wake/compiler/source_unit_name_resolver.py b/wake/compiler/source_unit_name_resolver.py index 8ec32e615..7390a0800 100644 --- a/wake/compiler/source_unit_name_resolver.py +++ b/wake/compiler/source_unit_name_resolver.py @@ -37,17 +37,10 @@ def apply_remapping(self, parent_source_unit: str, source_unit_name: str) -> str if len(matching_remappings) == 0: return source_unit_name - # longest prefix wins, if there are multiple remappings with the same prefix, choose the last one - matching_remappings.sort(key=lambda r: len(r.prefix), reverse=True) - - # choose the remapping with the longest context - # if there are multiple remappings with the same context, choose the last one - l = len(matching_remappings[0].prefix) + # choose longest context, then longest prefix, then last one wins + # https://github.com/argotorg/solidity/blob/ed1f49b4739f065b589616250602071536e26c3a/libsolidity/interface/ImportRemapper.cpp#L42-L69 + matching_remappings.sort(key=lambda r: (len(r.context or ""), len(r.prefix))) target_remapping = matching_remappings[-1] - for i in range(1, len(matching_remappings)): - if len(matching_remappings[i].prefix) != l: - target_remapping = matching_remappings[i - 1] - break return source_unit_name.replace( str(target_remapping.prefix), target_remapping.target or "", 1 diff --git a/wake/config/wake_config.py b/wake/config/wake_config.py index 01ae4f82f..276068d34 100644 --- a/wake/config/wake_config.py +++ b/wake/config/wake_config.py @@ -536,7 +536,7 @@ def max_solidity_version(self) -> SolidityVersion: Returns: Maximum supported Solidity version. """ - return SolidityVersion.fromstring("0.8.30") + return SolidityVersion.fromstring("0.8.33") @property def detectors(self) -> DetectorsConfig: diff --git a/wake/development/utils.py b/wake/development/utils.py index 4ae6f584b..99fe24217 100644 --- a/wake/development/utils.py +++ b/wake/development/utils.py @@ -31,6 +31,7 @@ from urllib.request import Request, urlopen import eth_utils +from Crypto.Hash import keccak from pydantic import TypeAdapter, ValidationError from wake_rs import keccak256 @@ -449,9 +450,16 @@ def _get_storage_value( type_info.value, types, ) + elif type_name.startswith("t_userDefinedValueType"): + raise ValueError(f"User defined value types are not supported") else: data = data.rjust(32, b"\x00") + if type_name.startswith("t_contract"): + type_name = "t_address" + elif type_name.startswith("t_enum"): + type_name = "t_uint8" + return Abi.decode([type_name[2:]], data)[0] if storage_layout_contract is None: @@ -696,13 +704,21 @@ def _set_storage_value( type_info.value, types, ) + elif type_name.startswith("t_userDefinedValueType"): + raise ValueError(f"User defined value types are not supported") else: original_data = bytearray( contract.chain.chain_interface.get_storage_at( str(contract.address), slot, "pending" ) ) - encoded_value = Abi.encode_packed([type_name[2:]], [value]) + if type_name.startswith("t_contract"): + encoded_value = Abi.encode_packed(["address"], [value]) + elif type_name.startswith("t_enum"): + encoded_value = Abi.encode_packed(["uint8"], [value]) + else: + encoded_value = Abi.encode_packed([type_name[2:]], [value]) + original_data[ -offset - type_info.number_of_bytes : (-offset if offset != 0 else None) ] = encoded_value diff --git a/wake/lsp/lsp_compiler.py b/wake/lsp/lsp_compiler.py index 66ea8c7a4..a6125bf89 100644 --- a/wake/lsp/lsp_compiler.py +++ b/wake/lsp/lsp_compiler.py @@ -1361,7 +1361,11 @@ async def __check_target_versions(self, *, show_message: bool) -> None: raise CompilationError("Invalid target version") async def __detect_target_versions( - self, compilation_units: List[CompilationUnit], *, show_message: bool + self, + graph: nx.DiGraph, + compilation_units: List[CompilationUnit], + *, + show_message: bool, ) -> Tuple[List[SolidityVersion], List[CompilationUnit], List[str]]: min_version = self.__config.min_solidity_version max_version = self.__config.max_solidity_version @@ -1377,13 +1381,15 @@ async def __detect_target_versions( compilation_unit.subproject ].target_version ) + files_str = "\n".join( + f" {su}: {graph.nodes[su]['versions']}" + for su in compilation_unit.source_unit_names + ) if target_version is not None: if target_version not in compilation_unit.versions: message = ( f"Unable to compile the following files with solc version `{target_version}` set in config:\n" - + "\n".join( - path_to_uri(path) for path in compilation_unit.files - ) + + files_str ) await self.__server.log_message(message, MessageType.WARNING) @@ -1405,9 +1411,7 @@ async def __detect_target_versions( if len(matching_versions) == 0: message = ( f"Unable to find a matching version of Solidity for the following files:\n" - + "\n".join( - path_to_uri(path) for path in compilation_unit.files - ) + + files_str ) await self.__server.log_message(message, MessageType.WARNING) @@ -1429,9 +1433,7 @@ async def __detect_target_versions( except StopIteration: message = ( f"The maximum supported version of Solidity is {max_version}, unable to compile the following files:\n" - + "\n".join( - path_to_uri(path) for path in compilation_unit.files - ) + + files_str ) await self.__server.log_message(message, MessageType.WARNING) @@ -1447,9 +1449,7 @@ async def __detect_target_versions( if target_version < min_version: message = ( f"The minimum supported version of Solidity is {min_version}, unable to compile the following files:\n" - + "\n".join( - path_to_uri(path) for path in compilation_unit.files - ) + + files_str ) await self.__server.log_message(message, MessageType.WARNING) @@ -1564,7 +1564,9 @@ async def bytecode_compile( target_versions, skipped_compilation_units, skipped_reasons, - ) = await self.__detect_target_versions(compilation_units, show_message=True) + ) = await self.__detect_target_versions( + graph, compilation_units, show_message=True + ) skipped_source_units = {} for compilation_unit, reason in zip(skipped_compilation_units, skipped_reasons): @@ -1829,7 +1831,9 @@ async def __compile( target_versions, skipped_compilation_units, _, - ) = await self.__detect_target_versions(compilation_units, show_message=False) + ) = await self.__detect_target_versions( + graph, compilation_units, show_message=False + ) await self.__install_solc(target_versions) # files passed as files_to_compile and files importing them diff --git a/wake/mutators/__init__.py b/wake/mutators/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/wake/mutators/api.py b/wake/mutators/api.py new file mode 100644 index 000000000..85b59f3d3 --- /dev/null +++ b/wake/mutators/api.py @@ -0,0 +1,90 @@ +from abc import ABCMeta +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path +from typing import List, Optional, TYPE_CHECKING + +from wake.core.visitor import Visitor + +if TYPE_CHECKING: + from wake.ir import IrAbc + + +class MutantStatus(Enum): + PENDING = auto() + KILLED = auto() + SURVIVED = auto() + TIMEOUT = auto() + COMPILE_ERROR = auto() + + +@dataclass(frozen=True, eq=True) +class Mutation: + """Immutable representation of a single mutation.""" + operator: str + file_path: Path + byte_start: int + byte_end: int + original: str + replacement: str + description: str + node_id: Optional[int] = None + status: MutantStatus = MutantStatus.PENDING + + @property + def id(self) -> str: + import hashlib + data = f"{self.file_path}:{self.byte_start}:{self.original}:{self.replacement}" + return hashlib.sha256(data.encode()).hexdigest()[:12] + + @property + def line_number(self) -> int: + """Convert byte offset to 1-indexed line number.""" + content = self.file_path.read_text() + return content[:self.byte_start].count('\n') + 1 + + def apply(self, source: bytes) -> bytes: + return source[:self.byte_start] + self.replacement.encode() + source[self.byte_end:] + + +class Mutator(Visitor, metaclass=ABCMeta): + """Base class for mutation operators using Wake's IR.""" + + name: str = "base" + description: str = "Base mutation operator" + + def __init__(self): + self._mutations: List[Mutation] = [] + self._current_file: Optional[Path] = None + + @property + def visit_mode(self) -> str: + return "paths" + + @property + def mutations(self) -> List[Mutation]: + return self._mutations + + def _add( + self, + node: "IrAbc", + original: str, + replacement: str, + description: Optional[str] = None, + ) -> None: + """Register a mutation from an IR node.""" + start, end = node.byte_location + + if description is None: + description = f"{original} → {replacement}" + + self._mutations.append(Mutation( + operator=self.name, + file_path=self._current_file, + byte_start=start, + byte_end=end, + original=original, + replacement=replacement, + description=description, + node_id=node.ast_node_id, + )) \ No newline at end of file diff --git a/wake/mutators/binary_operator_mutator.py b/wake/mutators/binary_operator_mutator.py new file mode 100644 index 000000000..38adf3b67 --- /dev/null +++ b/wake/mutators/binary_operator_mutator.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from typing import Dict, List, Tuple, Union + +from wake.mutators.api import Mutator +from wake.ir.expressions.binary_operation import BinaryOperation +from wake.ir.enums import BinaryOpOperator + + +class BinaryOperatorMutator(Mutator): + """Base class for binary operator replacement mutators.""" + + operator_map: Dict[BinaryOpOperator, Union[BinaryOpOperator, List[BinaryOpOperator]]] = {} + + def visit_binary_operation(self, node: BinaryOperation): + if node.operator not in self.operator_map: + return + + replacements = self.operator_map[node.operator] + if not isinstance(replacements, list): + replacements = [replacements] + + left = node.left_expression.source + right = node.right_expression.source + + for replacement_op in replacements: + self._add( + node=node, + original=node.source, + replacement=f"{left} {replacement_op.value} {right}", + ) \ No newline at end of file diff --git a/wake/mutators/literal_mutator.py b/wake/mutators/literal_mutator.py new file mode 100644 index 000000000..63d3ad79e --- /dev/null +++ b/wake/mutators/literal_mutator.py @@ -0,0 +1,31 @@ +from abc import abstractmethod +from typing import List + +from wake.mutators.api import Mutator +from wake.ir.expressions.literal import Literal +from wake.ir.enums import LiteralKind + + +class LiteralMutator(Mutator): + """Base class for literal replacement mutators.""" + + # Subclasses specify which literal kinds to target + target_kinds: List[LiteralKind] = [] + + def visit_literal(self, node: Literal): + if node.kind not in self.target_kinds: + return + + replacements = self.get_replacements(node) + + for replacement in replacements: + self._add( + node=node, + original=node.source, + replacement=replacement, + ) + + @abstractmethod + def get_replacements(self, node: Literal) -> List[str]: + """Return list of replacement values for this literal.""" + ... \ No newline at end of file diff --git a/wake_mutators/__init__.py b/wake_mutators/__init__.py new file mode 100644 index 000000000..c35b75315 --- /dev/null +++ b/wake_mutators/__init__.py @@ -0,0 +1,19 @@ +from .plus_for_minus_replacement import PlusForMinusReplacement +from .minus_for_plus_replacement import MinusForPlusReplacement +from .equality_for_inequality_replacement import EqualityForInequalityReplacement +from .inequality_for_equality_replacement import InequalityForEqualityReplacement +from .modifier_deletion import ModifierDeletion +from .greater_for_less_equal_to import GreaterForLessEqualTo +from .greater_or_equal_to_for_less import GreaterOrEqualToForLess +from .less_equal_to_for_greater import LessEqualToForGreater +from .less_for_greater_equal_to import LessForGreaterEqualTo +from .boolean_flip import BooleanFlipMutator +from .require_removal import RequireRemoval +from .and_for_or_replacement import AndForOrReplacement +from .or_for_and_replacement import OrForAndReplacement +from .receive_fallback_removal import ReceiveFallbackRemoval +from .revert_message_removal import RevertMessageRemoval +from .revert_error_removal import RevertErrorRemoval +from .payable_removal import PayableRemoval +from .public_external_replacement import PublicExternalReplacement +from .external_public_replacement import ExternalPublicReplacement diff --git a/wake_mutators/and_for_or_replacement.py b/wake_mutators/and_for_or_replacement.py new file mode 100644 index 000000000..74ecd5fc0 --- /dev/null +++ b/wake_mutators/and_for_or_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class AndForOrReplacement(BinaryOperatorMutator): + """Replace && with ||""" + + name = "and_for_or" + description = "Replace && with ||" + + operator_map = { + BinaryOpOperator.BOOLEAN_AND: BinaryOpOperator.BOOLEAN_OR, + } diff --git a/wake_mutators/boolean_flip.py b/wake_mutators/boolean_flip.py new file mode 100644 index 000000000..24aa06274 --- /dev/null +++ b/wake_mutators/boolean_flip.py @@ -0,0 +1,19 @@ +from wake.mutators.literal_mutator import LiteralMutator +from wake.ir.expressions.literal import Literal +from wake.ir.enums import LiteralKind + + +class BooleanFlipMutator(LiteralMutator): + """ Flip boolean literals (true, false) """ + + name = "boolean_literal" + description = "Flip boolean literals (true, false)" + + target_kinds = [LiteralKind.BOOL] + + def get_replacements(self, node: Literal) -> list[str]: + if node.value == "true": + return ["false"] + elif node.value == "false": + return ["true"] + return [] diff --git a/wake_mutators/emit_event_deletion.py b/wake_mutators/emit_event_deletion.py new file mode 100644 index 000000000..8abeb6639 --- /dev/null +++ b/wake_mutators/emit_event_deletion.py @@ -0,0 +1,17 @@ +from wake.mutators.api import Mutator +from wake.ir.statements.emit_statement import EmitStatement + + +class EmitEventDeletion(Mutator): + """Remove emit statements.""" + + name = "emit_event_deletion" + description = "Remove emit statements" + + def visit_emit_statement(self, node: EmitStatement): + self._add( + node=node, + original=node.source, + replacement="true", + description="Remove emit statement", + ) diff --git a/wake_mutators/equality_for_inequality_replacement.py b/wake_mutators/equality_for_inequality_replacement.py new file mode 100644 index 000000000..e6ec84093 --- /dev/null +++ b/wake_mutators/equality_for_inequality_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class EqualityForInequalityReplacement(BinaryOperatorMutator): + """ Replace equality (==) with inequality (!=) """ + + name = "equality_for_inequality" + description = "Replace equality (==) with inequality (!=)" + + operator_map = { + BinaryOpOperator.EQ: BinaryOpOperator.NEQ, + } \ No newline at end of file diff --git a/wake_mutators/external_public_replacement.py b/wake_mutators/external_public_replacement.py new file mode 100644 index 000000000..1911b20f3 --- /dev/null +++ b/wake_mutators/external_public_replacement.py @@ -0,0 +1,46 @@ +import re + +from wake.mutators.api import Mutator +from wake.ir.declarations.function_definition import FunctionDefinition +from wake.ir.enums import FunctionKind +from wake.ir.enums import Visibility + + +class ExternalPublicReplacement(Mutator): + """Replace external with public.""" + + name = "external_public_replacement" + description = "Replace external with public" + + def visit_function_definition(self, node: FunctionDefinition): + + if node.kind in [ + FunctionKind.CONSTRUCTOR, + FunctionKind.RECEIVE, # must be external + FunctionKind.FALLBACK, # must be external + ]: + return + + if node.visibility != Visibility.EXTERNAL: + return + + source = node.source + header_end = source.find("{") + if header_end == -1: + header_end = source.find(";") + if header_end == -1: + header_end = len(source) + + header = source[:header_end] + body = source[header_end:] + new_header = re.sub(r"\bexternal\b", "public", header, count=1) + replacement = new_header + body + if replacement == source: + return + + self._add( + node=node, + original=source, + replacement=replacement, + description="Replace external with public", + ) diff --git a/wake_mutators/greater_for_less_equal_to.py b/wake_mutators/greater_for_less_equal_to.py new file mode 100644 index 000000000..837521f7f --- /dev/null +++ b/wake_mutators/greater_for_less_equal_to.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class GreaterForLessEqualTo(BinaryOperatorMutator): + """Replace greater than (>) with less than or equal to (<=)""" + + name = "greater_for_less_equal_to" + description = "Replace greater than (>) with less than or equal to (<=)" + + operator_map = { + BinaryOpOperator.GT: BinaryOpOperator.LTE, + } diff --git a/wake_mutators/greater_or_equal_to_for_less.py b/wake_mutators/greater_or_equal_to_for_less.py new file mode 100644 index 000000000..1b0475bb0 --- /dev/null +++ b/wake_mutators/greater_or_equal_to_for_less.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class GreaterOrEqualToForLess(BinaryOperatorMutator): + """Replace greater than or equal to (>=) with less than (<)""" + + name = "greater_or_equal_to_for_less" + description = "Replace greater than or equal to (>=) with less than (<)" + + operator_map = { + BinaryOpOperator.GTE: BinaryOpOperator.LT, + } diff --git a/wake_mutators/inequality_for_equality_replacement.py b/wake_mutators/inequality_for_equality_replacement.py new file mode 100644 index 000000000..4c82850e2 --- /dev/null +++ b/wake_mutators/inequality_for_equality_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class InequalityForEqualityReplacement(BinaryOperatorMutator): + """Replace inequality (!=) with equality (==)""" + + name = "inequality_for_equality" + description = "Replace inequality (!=) with equality (==)" + + operator_map = { + BinaryOpOperator.NEQ: BinaryOpOperator.EQ, + } \ No newline at end of file diff --git a/wake_mutators/less_equal_to_for_greater.py b/wake_mutators/less_equal_to_for_greater.py new file mode 100644 index 000000000..03877cc25 --- /dev/null +++ b/wake_mutators/less_equal_to_for_greater.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class LessEqualToForGreater(BinaryOperatorMutator): + """Replace less than or equal to (<=) with greater than (>)""" + + name = "less_equal_to_for_greater" + description = "Replace less than or equal to (<=) with greater than (>)" + + operator_map = { + BinaryOpOperator.LTE: BinaryOpOperator.GT, + } diff --git a/wake_mutators/less_for_greater_equal_to.py b/wake_mutators/less_for_greater_equal_to.py new file mode 100644 index 000000000..57be3c458 --- /dev/null +++ b/wake_mutators/less_for_greater_equal_to.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class LessForGreaterEqualTo(BinaryOperatorMutator): + """Replace less than (<) with greater than or equal to (>=)""" + + name = "less_for_greater_equal_to" + description = "Replace less than (<) with greater than or equal to (>=)" + + operator_map = { + BinaryOpOperator.LT: BinaryOpOperator.GTE, + } diff --git a/wake_mutators/minus_for_plus_replacement.py b/wake_mutators/minus_for_plus_replacement.py new file mode 100644 index 000000000..be008ba6a --- /dev/null +++ b/wake_mutators/minus_for_plus_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class MinusForPlusReplacement(BinaryOperatorMutator): + """Replace subtraction (-) with addition (+)""" + + name = "minus_for_plus" + description = "Replace subtraction (-) with addition (+)" + + operator_map = { + BinaryOpOperator.MINUS: BinaryOpOperator.PLUS, + } \ No newline at end of file diff --git a/wake_mutators/modifier_deletion.py b/wake_mutators/modifier_deletion.py new file mode 100644 index 000000000..f3916a6e1 --- /dev/null +++ b/wake_mutators/modifier_deletion.py @@ -0,0 +1,23 @@ +from wake.mutators.api import Mutator +from wake.ir.declarations.function_definition import FunctionDefinition + + +class ModifierDeletion(Mutator): + """Remove function modifiers one at a time""" + + name = "modifier_deletion" + description = "Delete function modifiers (e.g., onlyOwner)" + + def visit_function_definition(self, node: FunctionDefinition): + if not node.modifiers: + return + + for modifier in node.modifiers: + mod_source = modifier.source + + self._add( + node=modifier, + original=mod_source, + replacement="", + description=self.description, + ) \ No newline at end of file diff --git a/wake_mutators/or_for_and_replacement.py b/wake_mutators/or_for_and_replacement.py new file mode 100644 index 000000000..3f23d4379 --- /dev/null +++ b/wake_mutators/or_for_and_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class OrForAndReplacement(BinaryOperatorMutator): + """Replace || with &&""" + + name = "or_for_and" + description = "Replace || with &&" + + operator_map = { + BinaryOpOperator.BOOLEAN_OR: BinaryOpOperator.BOOLEAN_AND, + } \ No newline at end of file diff --git a/wake_mutators/payable_removal.py b/wake_mutators/payable_removal.py new file mode 100644 index 000000000..475c23db4 --- /dev/null +++ b/wake_mutators/payable_removal.py @@ -0,0 +1,42 @@ +import re + +from wake.mutators.api import Mutator +from wake.ir.declarations.function_definition import FunctionDefinition +from wake.ir.enums import StateMutability +from wake.ir.enums import FunctionKind + + +class PayableRemoval(Mutator): + """Remove payable mutability from function definitions.""" + + name = "payable_removal" + description = "Remove payable mutability" + + def visit_function_definition(self, node: FunctionDefinition): + if node.state_mutability != StateMutability.PAYABLE: + return + + if node.kind == FunctionKind.RECEIVE: + # Receive function must be payable for compilation + return + + source = node.source + header_end = source.find("{") + if header_end == -1: + header_end = source.find(";") + if header_end == -1: + header_end = len(source) + + header = source[:header_end] + body = source[header_end:] + new_header = re.sub(r"\bpayable\b\s*", "", header, count=1) + replacement = new_header + body + if replacement == source: + return + + self._add( + node=node, + original=source, + replacement=replacement, + description="Remove payable", + ) diff --git a/wake_mutators/plus_for_minus_replacement.py b/wake_mutators/plus_for_minus_replacement.py new file mode 100644 index 000000000..496d0c229 --- /dev/null +++ b/wake_mutators/plus_for_minus_replacement.py @@ -0,0 +1,13 @@ +from wake.mutators.binary_operator_mutator import BinaryOperatorMutator +from wake.ir.enums import BinaryOpOperator + + +class PlusForMinusReplacement(BinaryOperatorMutator): + """Replace addition (+) with subtraction (-)""" + + name = "plus_for_minus" + description = "Replace addition (+) with subtraction (-)" + + operator_map = { + BinaryOpOperator.PLUS: BinaryOpOperator.MINUS, + } \ No newline at end of file diff --git a/wake_mutators/public_external_replacement.py b/wake_mutators/public_external_replacement.py new file mode 100644 index 000000000..7e16e7899 --- /dev/null +++ b/wake_mutators/public_external_replacement.py @@ -0,0 +1,45 @@ +import re + +from wake.mutators.api import Mutator +from wake.ir.declarations.function_definition import FunctionDefinition +from wake.ir.enums import FunctionKind +from wake.ir.enums import Visibility + + +class PublicExternalReplacement(Mutator): + """Replace public with external.""" + + name = "public_external_replacement" + description = "Replace public with external" + + def visit_function_definition(self, node: FunctionDefinition): + if node.kind in [ + FunctionKind.CONSTRUCTOR, + FunctionKind.RECEIVE, # must be external + FunctionKind.FALLBACK, # must be external + ]: + return + + if node.visibility != Visibility.PUBLIC: + return + + source = node.source + header_end = source.find("{") + if header_end == -1: + header_end = source.find(";") + if header_end == -1: + header_end = len(source) + + header = source[:header_end] + body = source[header_end:] + new_header = re.sub(r"\bpublic\b", "external", header, count=1) + replacement = new_header + body + if replacement == source: + return + + self._add( + node=node, + original=source, + replacement=replacement, + description="Replace public with external", + ) diff --git a/wake_mutators/receive_fallback_removal.py b/wake_mutators/receive_fallback_removal.py new file mode 100644 index 000000000..c5a3e76b6 --- /dev/null +++ b/wake_mutators/receive_fallback_removal.py @@ -0,0 +1,20 @@ +from wake.mutators.api import Mutator +from wake.ir.declarations.function_definition import FunctionDefinition +from wake.ir.enums import FunctionKind + + +class ReceiveFallbackRemoval(Mutator): + """Remove receive() and fallback() functions.""" + + name = "receive_fallback_removal" + description = "Remove receive() and fallback() functions" + + def visit_function_definition(self, node: FunctionDefinition): + if node.kind not in (FunctionKind.RECEIVE, FunctionKind.FALLBACK): + return + self._add( + node=node, + original=node.source, + replacement="", + description=f"Remove {node.kind.value} function", + ) diff --git a/wake_mutators/require_removal.py b/wake_mutators/require_removal.py new file mode 100644 index 000000000..ddc3775b2 --- /dev/null +++ b/wake_mutators/require_removal.py @@ -0,0 +1,24 @@ +from wake.mutators.api import Mutator +from wake.ir.statements.revert_statement import RevertStatement +from wake.ir.expressions.function_call import FunctionCall + +class RequireRemoval(Mutator): + """Remove require statements""" + + name = "require_removal" + description = "Remove require/assert statements" + + def visit_function_call(self, node: FunctionCall): + name = None + if hasattr(node, 'function_name'): + name = node.function_name + elif hasattr(node.expression, 'name'): + name = node.expression.name + + if name in ("require", "assert"): + self._add( + node=node, + original=node.source, + replacement="true", + description=f"Remove {name}", + ) \ No newline at end of file diff --git a/wake_mutators/revert_error_removal.py b/wake_mutators/revert_error_removal.py new file mode 100644 index 000000000..b06311a98 --- /dev/null +++ b/wake_mutators/revert_error_removal.py @@ -0,0 +1,18 @@ +from wake.mutators.api import Mutator +from wake.ir.statements.revert_statement import RevertStatement + + +class RevertErrorRemoval(Mutator): + """Replace revert Error(...) with revert().""" + + name = "revert_error_removal" + description = "Replace revert Error(...) with revert()" + + def visit_revert_statement(self, node: RevertStatement): + print(node.source) + self._add( + node=node, + original=node.source, + replacement="revert()", + description="Remove revert error", + ) diff --git a/wake_mutators/revert_message_removal.py b/wake_mutators/revert_message_removal.py new file mode 100644 index 000000000..2e52ad66a --- /dev/null +++ b/wake_mutators/revert_message_removal.py @@ -0,0 +1,19 @@ +from wake.mutators.api import Mutator +from wake.ir.expressions.function_call import FunctionCall +from wake.ir.enums import GlobalSymbol + + +class RevertMessageRemoval(Mutator): + """Replace revert("...") with revert().""" + + name = "revert_message_removal" + description = "Replace revert(message) with revert()" + + def visit_function_call(self, node: FunctionCall): + if node.function_called == GlobalSymbol.REVERT and len(node.arguments) != 0: + self._add( + node=node, + original=node.source, + replacement="revert()", + description="Remove revert message", + )