diff --git a/src/optimization/analyzer.ml b/src/optimization/analyzer.ml
index 29d1adaf0f5..29916aaf7b4 100644
--- a/src/optimization/analyzer.ml
+++ b/src/optimization/analyzer.ml
@@ -208,6 +208,8 @@ module type DataFlowApi = sig
 	val commit : analyzer_context -> opt_ctx -> unit
 	(* Whether or not conditional branches are checked *)
 	val conditional : bool
+	(* Narrow a phi incoming value based on the edge condition *)
+	val narrow : analyzer_context -> opt_ctx -> BasicBlock.cfg_edge -> texpr -> t -> t
 end
 
 (*
@@ -244,9 +246,12 @@ module DataFlow (M : DataFlowApi) = struct
 		in
 		let visit_phi bb v el =
 			let el = List.fold_left2 (fun acc e edge ->
-				if has_flag edge M.flag then e :: acc else acc
+				if has_flag edge M.flag then (e, edge) :: acc else acc
 			) [] el bb.bb_incoming in
-			let el = List.map (fun e -> M.transfer actx ctx bb e) el in
+			let el = List.map (fun (e, edge) ->
+				let t = M.transfer actx ctx bb e in
+				M.narrow actx ctx edge e t
+			) el in
 			match el with
 				| e1 :: el when List.for_all (M.equals e1) el ->
 					e1;
@@ -511,6 +516,8 @@ module ConstPropagationImpl = struct
 
 	let init ctx = IntHashtbl.create 0
 
+	let narrow _ _ _ _ t = t
+
 	let commit actx ctx =
 		let inline e i = match get_cell ctx i with
 			| Top | Bottom | EnumValue _ | Null _ ->
@@ -610,6 +617,8 @@ module CopyPropagationImpl = struct
 
 	let init actx = IntHashtbl.create 0
 
+	let narrow _ _ _ _ t = t
+
 	let commit actx ctx =
 		let rec commit bb e = match e.eexpr with
 			| TLocal v when not (has_var_flag v VCaptured) ->
@@ -656,6 +665,245 @@ end
 
 module CopyPropagation = DataFlow(CopyPropagationImpl)
 
+(*
+	NullAnalysis implements a data-flow analysis for null safety using the DataFlow algorithm. Its lattice
+	tracks whether variables are provably non-null, provably null, or potentially null.
+
+	The analysis uses type information to determine initial nullability: expressions typed as Null<T> are
+	considered potentially null, while non-nullable types are considered not null.
+
+	At phi merge points, the narrow function refines incoming values based on null-check conditions,
+	enabling the analysis to understand patterns like:
+		if (x == null) { x = nonNullValue; }
+		// x is known to be non-null here
+
+	The commit phase walks the dominator tree with a narrowing context to check for unsafe operations:
+	field access, method calls, and array access on values that might be null. Every unsafe access is
+	reported once, in the order in which the walk encounters it.
+*)
+module NullAnalysisImpl = struct
+	open BasicBlock
+	open Graph
+
+	(* Lattice value tracked for each variable *)
+	type t =
+		| Top (* Not yet analyzed *)
+		| NotNull (* Provably non-null *)
+		| IsNull (* Provably null *)
+		| Bottom (* Potentially null - conservative default *)
+
+	(* Analysis state: one lattice cell per variable id plus the warnings collected by commit *)
+	type opt_ctx = {
+		cells : t IntHashtbl.t;
+		mutable warnings : (string * pos) list; (* newest first, reversed before being reported *)
+	}
+
+	let to_string = function
+		| Top -> "Top"
+		| Bottom -> "Bottom"
+		| NotNull -> "NotNull"
+		| IsNull -> "IsNull"
+
+	let conditional = false
+	let flag = FlagNullAnalysis
+
+	(* A variable without a cell has not been analyzed yet and reads as Top *)
+	let get_cell ctx i = try IntHashtbl.find ctx.cells i with Not_found -> Top
+	let set_cell ctx i ct = IntHashtbl.replace ctx.cells i ct
+
+	let top = Top
+	let bottom = Bottom
+
+	let equals t1 t2 = match t1,t2 with
+		| Top,Top | Bottom,Bottom | NotNull,NotNull | IsNull,IsNull -> true
+		| _ -> false
+
+	let is_nullable_type t =
+		NullSafety.is_nullable_type t
+
+	(* Compute the lattice value of expression e. The null constant is IsNull; other constants, freshly
+	   constructed values, type expressions and remaining operators are NotNull; locals read their cell
+	   (Dynamic locals are Bottom); for fields, enum accessors, calls, typed casts and array reads the
+	   static type decides. Anything else is Bottom. *)
+	let transfer actx ctx bb e =
+		let rec eval e = match e.eexpr with
+			| TConst TNull ->
+				IsNull
+			| TConst _ ->
+				NotNull
+			| TLocal v ->
+				if (follow v.v_type) == t_dynamic then
+					Bottom
+				else
+					get_cell ctx v.v_id
+			| TBinop(OpAssign,_,e2) ->
+				eval e2
+			| TNew _ | TArrayDecl _ | TObjectDecl _ | TFunction _ ->
+				NotNull
+			| TTypeExpr _ ->
+				NotNull
+			| TField _ | TEnumParameter _ | TEnumIndex _ ->
+				if is_nullable_type e.etype then Bottom else NotNull
+			| TCall _ ->
+				if is_nullable_type e.etype then Bottom else NotNull
+			| TParenthesis e1 | TMeta(_,e1) | TCast(e1,None) ->
+				eval e1
+			| TCast(_,Some _) ->
+				if is_nullable_type e.etype then Bottom else NotNull
+			| TUnop _ | TBinop _ ->
+				NotNull
+			| TArray _ ->
+				if is_nullable_type e.etype then Bottom else NotNull
+			| _ ->
+				Bottom
+		in
+		eval e
+
+	(* Try to extract a null-check from a condition expression.
+	   Returns Some (var_id, is_eq_null) if the condition is `v == null` or `v != null`, looking through
+	   parentheses, metadata and temp vars for which get_var_value finds a value expression. *)
+	let rec get_null_check actx e =
+		match e.eexpr with
+		| TBinop(OpEq, {eexpr = TLocal v}, {eexpr = TConst TNull})
+		| TBinop(OpEq, {eexpr = TConst TNull}, {eexpr = TLocal v}) ->
+			Some (v.v_id, true)
+		| TBinop(OpNotEq, {eexpr = TLocal v}, {eexpr = TConst TNull})
+		| TBinop(OpNotEq, {eexpr = TConst TNull}, {eexpr = TLocal v}) ->
+			Some (v.v_id, false)
+		| TParenthesis e1 | TMeta(_,e1) ->
+			get_null_check actx e1
+		| TLocal v ->
+			(* The condition is a temp var bound to a null check expression *)
+			begin try
+				let value_expr = get_var_value actx.graph v in
+				get_null_check actx value_expr
+			with Not_found ->
+				None
+			end
+		| _ ->
+			None
+
+	(* Determine the null state of a variable on a given edge based on
+	   the null-check condition that terminates the source block of the edge.
+	   Edges that are neither the then- nor the else-edge of a branch yield `default`. *)
+	let narrowed_state_from_edge actx edge checked_var_id is_eq_null default =
+		match edge.cfg_kind with
+		| CFGCondBranch _ ->
+			(* Then branch: condition was true *)
+			if is_eq_null then IsNull else NotNull
+		| CFGCondElse ->
+			(* Else branch: condition was false *)
+			if is_eq_null then NotNull else IsNull
+		| _ -> default
+
+	(* Refine the lattice value `t` of the phi operand `e` arriving over `edge`: if the source block of the
+	   edge branches on a null-check of the variable that `e` reads, the branch direction decides the
+	   state. In every other case `t` is returned unchanged. *)
+	let narrow actx ctx edge e t =
+		match edge.cfg_from.bb_terminator with
+		| TermCondBranch cond_expr ->
+			begin match get_null_check actx cond_expr with
+			| Some (checked_var_id, is_eq_null) ->
+				begin match e.eexpr with
+				| TLocal v when v.v_id = checked_var_id ->
+					narrowed_state_from_edge actx edge checked_var_id is_eq_null t
+				| _ -> t
+				end
+			| None -> t
+			end
+		| _ -> t
+
+	let init actx =
+		{ cells = IntHashtbl.create 0; warnings = [] }
+
+	(* Walk the dominator tree and collect a warning for every field access, method call or array access
+	   whose subject is a nullable-typed local that is not known to be NotNull, then report them. *)
+	let commit actx ctx =
+		(* Narrowing context: var_id -> null_state overrides for the blocks currently being walked.
+		   push_narrowings returns the shadowed entries so that pop_narrowings can restore them. *)
+		let narrowing_stack = Hashtbl.create 0 in
+		let get_state v_id =
+			try Hashtbl.find narrowing_stack v_id
+			with Not_found -> get_cell ctx v_id
+		in
+		let push_narrowings narrowings =
+			List.map (fun (v_id, state) ->
+				let prev = try Some (Hashtbl.find narrowing_stack v_id) with Not_found -> None in
+				Hashtbl.replace narrowing_stack v_id state;
+				(v_id, prev)
+			) narrowings
+		in
+		let pop_narrowings saved =
+			List.iter (fun (v_id, prev) ->
+				match prev with
+				| Some s -> Hashtbl.replace narrowing_stack v_id s
+				| None -> Hashtbl.remove narrowing_stack v_id
+			) saved
+		in
+		(* Determine narrowings for a block based on incoming edges *)
+		let block_narrowings bb =
+			match bb.bb_incoming with
+			| [edge] ->
+				begin match edge.cfg_from.bb_terminator with
+				| TermCondBranch cond_expr ->
+					begin match get_null_check actx cond_expr with
+					| Some (checked_var_id, is_eq_null) ->
+						let state = narrowed_state_from_edge actx edge checked_var_id is_eq_null (get_cell ctx checked_var_id) in
+						[(checked_var_id, state)]
+					| None -> []
+					end
+				| _ -> []
+				end
+			| _ -> []
+		in
+		let is_maybe_null v_id =
+			let state = get_state v_id in
+			match state with
+			| NotNull -> false
+			| _ -> true
+		in
+		let check_nullable_access e_subject description =
+			match e_subject.eexpr with
+			| TLocal v when is_nullable_type v.v_type ->
+				if is_maybe_null v.v_id then
+					ctx.warnings <- (Printf.sprintf "Null safety: %s on potentially null value '%s'" description v.v_name, e_subject.epos) :: ctx.warnings
+			| _ -> ()
+		in
+		let rec check_expr e =
+			match e.eexpr with
+			| TCall({eexpr = TField(e1,_)},el) when is_nullable_type e1.etype ->
+				check_nullable_access e1 "method call";
+				(* Descend into the receiver and the arguments directly: going through the TField node
+				   of the callee would report the very same access a second time as a "field access". *)
+				check_expr e1;
+				List.iter check_expr el
+			| _ ->
+				begin match e.eexpr with
+				| TField(e1,_) when is_nullable_type e1.etype ->
+					check_nullable_access e1 "field access"
+				| TArray(e1,_) when is_nullable_type e1.etype ->
+					check_nullable_access e1 "array access"
+				| _ -> ()
+				end;
+				Type.iter check_expr e
+		in
+		let rec walk bb =
+			let narrowings = block_narrowings bb in
+			let saved = push_narrowings narrowings in
+			DynArray.iter check_expr bb.bb_el;
+			terminator_iter check_expr bb.bb_terminator;
+			List.iter walk bb.bb_dominated;
+			pop_narrowings saved
+		in
+		walk actx.graph.g_root;
+		(* Warnings were accumulated newest-first; reverse them so they are reported in discovery order *)
+		List.iter (fun (msg,p) ->
+			SafeCom.add_warning actx.com WNullSafety msg p
+		) (List.rev ctx.warnings)
+end
+
+module NullAnalysis = DataFlow(NullAnalysisImpl)
+
 (*
 	LocalDce implements a mark & sweep dead code elimination. The mark phase follows the CFG edges of the graphs to find
 	variable usages and marks variables accordingly. If ConstPropagation was run before, only CFG edges which are
@@ -801,6 +1049,7 @@ module Debug = struct
 				| FlagExecutable -> "exe"
 				| FlagDce -> "dce"
 				| FlagCopyPropagation -> "copy"
+				| FlagNullAnalysis -> "null"
 			in
 			let label = label ^ match edge.cfg_flags with
 				| [] -> ""
@@ -1098,6 +1347,7 @@ module Run = struct
 			actx.with_timer ["optimize";"ssa-apply"] (fun () -> Ssa.apply actx);
 			if actx.config.const_propagation then actx.with_timer ["optimize";"const-propagation"] (fun () -> ConstPropagation.apply actx);
 			if actx.config.copy_propagation then actx.with_timer ["optimize";"copy-propagation"] (fun () -> CopyPropagation.apply actx);
+			if actx.config.null_safety then actx.with_timer ["optimize";"null-analysis"] (fun () -> NullAnalysis.apply actx);
 			actx.with_timer ["optimize";"local-dce"] (fun () -> LocalDce.apply actx);
 		end;
 		back_again actx is_real_function
diff --git a/src/optimization/analyzerConfig.ml b/src/optimization/analyzerConfig.ml
index 2945cd4f2bd..5fcfe9a1a01 100644
--- a/src/optimization/analyzerConfig.ml
+++ b/src/optimization/analyzerConfig.ml
@@ -38,6 +38,7 @@ type t = {
 	detail_times : int;
 	user_var_fusion : bool;
 	fusion_debug : bool;
+	null_safety : bool;
 }
 
 let flag_optimize = "optimize"
@@ -74,6 +75,7 @@ let get_base_config com =
 		detail_times = Timer.level_from_define com.defines Define.AnalyzerTimes;
 		user_var_fusion = (match com.platform with Flash | Jvm -> false | _ -> true) && (Define.raw_defined com.defines "analyzer_user_var_fusion" || (not com.debug && not (Define.raw_defined com.defines "analyzer_no_user_var_fusion")));
 		fusion_debug = false;
+		null_safety = Define.raw_defined com.defines "analyzer_check_null";
 	}
 
 let update_config_from_meta com config ml =
diff --git a/src/optimization/analyzerTypes.ml b/src/optimization/analyzerTypes.ml
index acec94b0328..292e2e9fb18 100644
--- a/src/optimization/analyzerTypes.ml
+++ b/src/optimization/analyzerTypes.ml
@@ -50,6 +50,7 @@ module BasicBlock = struct
 		| FlagExecutable (* Used by constant propagation to handle live edges *)
 		| FlagDce (* Used by DCE to keep track of handled edges *)
 		| FlagCopyPropagation (* Used by copy propagation to track handled eges *)
+		| FlagNullAnalysis (* Used by null analysis to track handled edges *)
 
 	type cfg_edge_kind =
 		| CFGGoto (* An unconditional branch *)
diff --git a/tests/optimization/src/TestNullChecker.hx b/tests/optimization/src/TestNullChecker.hx
index 75c97b0606a..09f72e93794 100644
--- a/tests/optimization/src/TestNullChecker.hx
+++ b/tests/optimization/src/TestNullChecker.hx
@@ -11,61 +11,55 @@ class TestNullChecker extends TestBase {
 		TestBaseMacro.run();
 	}
 
-	function test1() {
+	// Tests that the null analysis doesn't break basic null-flow patterns
+
+	function testAssignment() {
 		var ns = getNullString();
-		@:analyzer(testIsNull) ns;
 		ns = "foo";
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
 	}
 
-	function test2() {
+	function testReassignment() {
 		var s = getString();
-		@:analyzer(testIsNotNull) s;
+		useString(s);
 		s = getNullString();
-		@:analyzer(testIsNull) s;
 	}
 
-	function test3() {
+	function testNullCheckThen() {
 		var ns = getNullString();
 		if (ns == null) {
-			@:analyzer(testIsNull) ns;
 			ns = getString();
-			@:analyzer(testIsNotNull) ns;
 		}
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
 	}
 
-	function test4() {
+	function testNullCheckNotNull() {
 		var ns = getNullString();
 		if (ns != null) {
-			@:analyzer(testIsNotNull) ns;
 			ns = getNullString();
-			@:analyzer(testIsNull) ns;
 		}
-		@:analyzer(testIsNull) ns;
 	}
 
-	function test5() {
+	function testNullCheckElse() {
 		var ns = getNullString();
 		if (ns != null) {
-			@:analyzer(testIsNotNull) ns;
+			useString(ns);
 		} else {
-			@:analyzer(testIsNull) ns;
 			ns = getString();
 		}
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
 	}
 
-	function test6() {
+	function testNestedNullCheck() {
 		var ns = getNullString();
 		if (ns != null) {
-			@:analyzer(testIsNotNull) ns;
+			useString(ns);
 		} else {
 			if (ns == null) {
 				ns = getString();
 			}
 		}
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
 	}
 
 	function testReturn1() {
@@ -73,7 +67,7 @@ class TestNullChecker extends TestBase {
 		if (ns == null) {
 			return;
 		}
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
 	}
 
 	function testReturn2() {
@@ -83,26 +77,7 @@ class TestNullChecker extends TestBase {
 		} else {
 			return;
 		}
-		@:analyzer(testIsNotNull) ns;
-	}
-
-	// doesn't work yet due to || transformation
-	//function testReturn3() {
-		//var ns = getNullString();
-		//if (ns == null || getTrue()) {
-			//return;
-		//}
-		//@:analyzer(testIsNotNull) ns;
-	//}
-
-	function testReturn4() {
-		var ns = getNullString();
-		if (ns != null && getTrue()) {
-
-		} else {
-			return;
-		}
-		@:analyzer(testIsNull) ns;
+		useString(ns);
 	}
 
 	function testBreak() {
@@ -111,9 +86,8 @@ class TestNullChecker extends TestBase {
 			if (ns == null) {
 				break;
 			}
-			@:analyzer(testIsNotNull) ns;
+			useString(ns);
 		}
-		@:analyzer(testIsNull) ns;
 	}
 
 	function testContinue() {
@@ -125,9 +99,8 @@ class TestNullChecker extends TestBase {
 			if (ns == null) {
 				continue;
 			}
-			@:analyzer(testIsNotNull) ns;
+			useString(ns);
 		}
-		@:analyzer(testIsNull) ns;
 	}
 
 	function testThrow() {
@@ -135,7 +108,11 @@ class TestNullChecker extends TestBase {
 		if (ns == null) {
 			throw false;
 		}
-		@:analyzer(testIsNotNull) ns;
+		useString(ns);
+	}
+
+	function useString(s:String) {
+		// Consume a non-null String value, ensuring the analysis tracks nullability correctly
 	}
 
 	function getString() {