Skip to content

Commit 2130428

Browse files
committed
Refactor return type extraction in block arms for improved type checking consistency.
1 parent eeca5cd commit 2130428

File tree

4 files changed

+158
-32
lines changed

4 files changed

+158
-32
lines changed

examples/type_checking.ks

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ fn extract_header(ctx: *xdp_md) -> *PacketHeader {
7272
@helper
7373
fn classify_protocol(proto: u8) -> ProtocolType {
7474
// Type checker validates enum constant access
75-
match (proto) {
75+
return match (proto) {
7676
6: TCP,
7777
17: UDP,
7878
1: ICMP,
@@ -99,7 +99,7 @@ fn make_decision(header: PacketHeader) -> FilterDecision {
9999
// Type checker validates function call signatures
100100
var proto_type = classify_protocol(header.protocol)
101101

102-
match (proto_type) {
102+
return match (proto_type) {
103103
TCP: {
104104
// Type checker validates field access on struct types
105105
if (header.length > 1500) {
@@ -128,7 +128,7 @@ fn make_decision(header: PacketHeader) -> FilterDecision {
128128
var decision = make_decision(*packet_header)
129129

130130
// Type checker validates match expressions and enum types
131-
match (decision) {
131+
return match (decision) {
132132
Allow: XDP_PASS,
133133
Block: XDP_DROP,
134134
Log: {

src/ir_generator.ml

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,32 +1323,55 @@ and lower_statement ctx stmt =
13231323
let ret_val = lower_expression ctx expr in
13241324
IRReturnValue ret_val)
13251325
| Block stmts ->
1326-
(* For block arms, look for the return statement *)
1327-
let rec find_return_action = function
1328-
| [] -> failwith "Block arm must have a return statement"
1329-
| stmt :: rest ->
1330-
(match stmt.stmt_desc with
1331-
| Ast.Return (Some return_expr) ->
1332-
(match return_expr.expr_desc with
1333-
| Ast.Call (callee_expr, args) ->
1334-
(* Check if this is a simple function call that could be a tail call *)
1335-
(match callee_expr.expr_desc with
1336-
| Ast.Identifier name ->
1337-
let arg_vals = List.map (lower_expression ctx) args in
1338-
IRReturnCall (name, arg_vals)
1339-
| _ ->
1340-
(* Function pointer call - treat as regular return *)
1341-
let ret_val = lower_expression ctx return_expr in
1342-
IRReturnValue ret_val)
1343-
| Ast.TailCall (name, args) ->
1326+
(* For block arms, extract return action from the last statement *)
1327+
let rec extract_return_action_from_stmt stmt =
1328+
match stmt.stmt_desc with
1329+
| Ast.Return (Some return_expr) ->
1330+
(match return_expr.expr_desc with
1331+
| Ast.Call (callee_expr, args) ->
1332+
(* Check if this is a simple function call that could be a tail call *)
1333+
(match callee_expr.expr_desc with
1334+
| Ast.Identifier name ->
13441335
let arg_vals = List.map (lower_expression ctx) args in
1345-
IRReturnTailCall (name, arg_vals, 0)
1336+
IRReturnCall (name, arg_vals)
13461337
| _ ->
1338+
(* Function pointer call - treat as regular return *)
13471339
let ret_val = lower_expression ctx return_expr in
13481340
IRReturnValue ret_val)
1349-
| _ -> find_return_action rest)
1341+
| Ast.TailCall (name, args) ->
1342+
let arg_vals = List.map (lower_expression ctx) args in
1343+
IRReturnTailCall (name, arg_vals, 0)
1344+
| _ ->
1345+
let ret_val = lower_expression ctx return_expr in
1346+
IRReturnValue ret_val)
1347+
| Ast.ExprStmt expr ->
1348+
(* Handle implicit return from expression statement *)
1349+
(match expr.expr_desc with
1350+
| Ast.Call (callee_expr, args) ->
1351+
(match callee_expr.expr_desc with
1352+
| Ast.Identifier name ->
1353+
let arg_vals = List.map (lower_expression ctx) args in
1354+
IRReturnCall (name, arg_vals)
1355+
| _ ->
1356+
let ret_val = lower_expression ctx expr in
1357+
IRReturnValue ret_val)
1358+
| Ast.TailCall (name, args) ->
1359+
let arg_vals = List.map (lower_expression ctx) args in
1360+
IRReturnTailCall (name, arg_vals, 0)
1361+
| _ ->
1362+
let ret_val = lower_expression ctx expr in
1363+
IRReturnValue ret_val)
1364+
| Ast.If (_, then_stmts, Some _) ->
1365+
(* For if-else statements, we'll use the then branch action (both should be compatible) *)
1366+
extract_return_action_from_block then_stmts
1367+
| _ ->
1368+
failwith "Block arm must end with a return statement, expression, or if-else statement"
1369+
and extract_return_action_from_block stmts =
1370+
match List.rev stmts with
1371+
| last_stmt :: _ -> extract_return_action_from_stmt last_stmt
1372+
| [] -> failwith "Empty block in match arm"
13501373
in
1351-
find_return_action stmts
1374+
extract_return_action_from_block stmts
13521375
in
13531376

13541377
{ match_pattern = ir_pattern; return_action = return_action; arm_pos = arm.arm_pos }

src/type_checker.ml

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,31 @@ let is_truthy_type bpf_type =
618618
| Enum _ -> true (* enums: based on numeric value *)
619619
| _ -> false (* other types not allowed in boolean context *)
620620

621+
(** [extract_block_return_type stmts arm_pos] computes the type that a
    match-arm block yields, looking only at the last statement of [stmts].

    - a [TReturn] carrying an expression yields that expression's [texpr_type];
    - a [TExprStmt] (implicit return) yields the expression's [texpr_type];
    - a [TIf] with an else branch yields the result of [unify_types] applied
      to the types extracted recursively from both branches;
    - any other last statement (including an if without else), or an empty
      block, is reported through [type_error].

    Every error is reported at [arm_pos], the position supplied by the caller,
    not at the position of the offending statement. *)
622+
let rec extract_block_return_type stmts arm_pos =
623+
let extract_type_from_stmt stmt =
624+
match stmt.tstmt_desc with
625+
| TReturn (Some return_expr) -> return_expr.texpr_type
626+
| TExprStmt expr -> expr.texpr_type
627+
| TIf (_, then_stmts, Some else_stmts) ->
628+
(* For if-else statements, both branches must return compatible types.
   Each branch is itself treated as a block, so nested if-else chains are
   resolved recursively through their own last statements. *)
629+
let then_type = extract_block_return_type then_stmts arm_pos in
630+
let else_type = extract_block_return_type else_stmts arm_pos in
631+
(match unify_types then_type else_type with
632+
| Some unified_type -> unified_type
633+
| None -> type_error ("If-else branches have incompatible types: " ^
634+
string_of_bpf_type then_type ^ " vs " ^
635+
string_of_bpf_type else_type) arm_pos)
636+
| TIf (_, _, None) ->
637+
(* If without else - this doesn't work as a return value, because no type
   is available for the path where the condition is false *)
638+
type_error "If statement without else cannot be used as return value in match arm" arm_pos
639+
| _ ->
640+
(* Catch-all: every remaining statement form, including a return that
   carries no expression, is rejected here *)
type_error "Block arms must end with a return statement, expression, or if-else statement" arm_pos
641+
in
642+
(* Reverse the list so the last statement can be matched as the head *)
match List.rev stmts with
643+
| last_stmt :: _ -> extract_type_from_stmt last_stmt
644+
| [] -> type_error "Empty block in match arm" arm_pos
645+
621646
(** Type check a user function call *)
622647
let rec type_check_user_function_call ctx name typed_args arg_types pos =
623648
try
@@ -1197,19 +1222,13 @@ and type_check_expression ctx expr =
11971222
let first_type = match first_arm.tarm_body with
11981223
| TSingleExpr expr -> expr.texpr_type
11991224
| TBlock stmts ->
1200-
(* For block arms, we need to find the return type from the last statement *)
1201-
match List.rev stmts with
1202-
| { tstmt_desc = TReturn (Some return_expr); _ } :: _ -> return_expr.texpr_type
1203-
| _ -> type_error "Block arms must end with a return statement" first_arm.tarm_pos
1225+
extract_block_return_type stmts first_arm.tarm_pos
12041226
in
12051227
List.iter (fun arm ->
12061228
let arm_type = match arm.tarm_body with
12071229
| TSingleExpr expr -> expr.texpr_type
12081230
| TBlock stmts ->
1209-
(* For block arms, we need to find the return type from the last statement *)
1210-
match List.rev stmts with
1211-
| { tstmt_desc = TReturn (Some return_expr); _ } :: _ -> return_expr.texpr_type
1212-
| _ -> type_error "Block arms must end with a return statement" arm.tarm_pos
1231+
extract_block_return_type stmts arm.tarm_pos
12131232
in
12141233
match unify_types first_type arm_type with
12151234
| Some _ -> () (* Compatible *)

tests/test_match.ml

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,89 @@ let test_nested_match_structures () =
461461

462462
check bool "nested match should generate nested conditional structures" true has_conditional_structure
463463

464+
(** Regression test for match arms whose block bodies end in an implicit
    return instead of an explicit return statement.

    The embedded program has four arms: a bare expression, a block ending in
    an if-else, a block holding a single expression, and a default arm. The
    test first runs the type checker over the parsed program and fails if it
    raises, then inspects the untyped AST to confirm the arm shapes that the
    parser produced. *)
465+
let test_match_block_implicit_returns () =
466+
let input = {|
467+
enum Decision {
468+
Accept = 0,
469+
Reject = 1,
470+
Review = 2
471+
}
472+
473+
fn process_value(value: u32) -> Decision {
474+
return match (value) {
475+
1: Accept,
476+
2: {
477+
if (value > 10) {
478+
Reject
479+
} else {
480+
Review
481+
}
482+
},
483+
3: {
484+
Review
485+
},
486+
default: Reject
487+
}
488+
}
489+
|} in
490+
491+
let ast = Parse.parse_string input in
492+
let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in
493+
494+
(* The main test: Type checking should succeed (this would fail before the bug fix).
   The check inside the try compares true with true, so it cannot fail on its
   own; the effective assertion is that the type checker call returns without
   raising.
   NOTE(review): the first handler compares the message for exact equality
   with the error text of the previous implementation - confirm that the type
   checker still emits that exact text and that it reports errors as Failure;
   otherwise a regression is reported by one of the two generic handlers. *)
495+
(try
496+
let (_typed_ast, _typed_functions) = Type_checker.type_check_and_annotate_ast ~symbol_table:(Some symbol_table) ast in
497+
check bool "type checking should succeed for match arms with implicit returns" true true
498+
with
499+
| Failure msg when msg = "Block arms must end with a return statement" ->
500+
fail "Bug regression: type checker still requires explicit returns in match arm blocks"
501+
| Failure msg ->
502+
failwith ("Type checking failed with different error: " ^ msg)
503+
| exn ->
504+
failwith ("Unexpected type checking error: " ^ (Printexc.to_string exn)));
505+
506+
(* Verify the structure was parsed correctly: locate the process_value
   declaration in the untyped AST *)
507+
let func = match List.find (function
508+
| GlobalFunction f when f.func_name = "process_value" -> true
509+
| _ -> false) ast with
510+
| GlobalFunction f -> f
511+
| _ -> failwith "Expected process_value function"
512+
in
513+
514+
(* The function body is expected to start with the return statement that
   wraps the match; despite its name, match_expr holds the list of arms *)
let return_stmt = List.hd func.func_body in
515+
let match_expr = match return_stmt.stmt_desc with
516+
| Return (Some expr) -> (match expr.expr_desc with
517+
| Match (_, arms) -> arms
518+
| _ -> failwith "Expected match expression")
519+
| _ -> failwith "Expected return statement with match"
520+
in
521+
522+
(* Verify we have the expected structure: three literal arms plus default *)
523+
check int "should have 4 match arms" 4 (List.length match_expr);
524+
525+
(* Verify the second arm (index 1) has a block with if-else (implicit return) *)
526+
let second_arm = List.nth match_expr 1 in
527+
(match second_arm.arm_body with
528+
| Block stmts ->
529+
check bool "second arm should have statements" true (List.length stmts > 0);
530+
(* Verify it's an if statement (implicit return, no explicit return needed).
   This inspects the first statement of the block; in this input the block
   holds a single statement, so first and last coincide. *)
531+
(match (List.hd stmts).stmt_desc with
532+
| If (_, _, Some _) -> () (* if-else statement - good *)
533+
| _ -> failwith "Expected if-else statement in second arm")
534+
| _ -> failwith "Expected block in second arm");
535+
536+
(* Verify the third arm (index 2) has a block with expression (implicit return) *)
537+
let third_arm = List.nth match_expr 2 in
538+
(match third_arm.arm_body with
539+
| Block stmts ->
540+
check bool "third arm should have statements" true (List.length stmts > 0);
541+
(* Verify it's an expression statement (implicit return); as above, the
   first statement is also the only statement of this block *)
542+
(match (List.hd stmts).stmt_desc with
543+
| ExprStmt _ -> () (* expression statement - good *)
544+
| _ -> failwith "Expected expression statement in third arm")
545+
| _ -> failwith "Expected block in third arm")
546+
464547
let suite = [
465548
"test_basic_match_parsing", `Quick, test_basic_match_parsing;
466549
"test_match_with_enums", `Quick, test_match_with_enums;
@@ -471,6 +554,7 @@ let suite = [
471554
"test_match_conditional_control_flow", `Quick, test_match_conditional_control_flow;
472555
"test_match_no_premature_execution", `Quick, test_match_no_premature_execution;
473556
"test_nested_match_structures", `Quick, test_nested_match_structures;
557+
"test_match_block_implicit_returns", `Quick, test_match_block_implicit_returns;
474558
]
475559

476560
let () = run "Match Construct Tests" [

0 commit comments

Comments
 (0)