Skip to content

Commit eeca5cd

Browse files
committed
Enhance type resolution in IR generation by adding support for type aliases and enums in the context of AST to IR type conversion. This includes preserving alias names and properly resolving struct and enum definitions from the symbol table.
1 parent c66322e commit eeca5cd

File tree

3 files changed

+109
-4
lines changed

3 files changed

+109
-4
lines changed

src/ir.ml

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,30 @@ let rec ast_type_to_ir_type = function
686686
(* Helper function that preserves type aliases when converting AST types to IR types *)
687687
let rec ast_type_to_ir_type_with_context symbol_table ast_type =
688688
match ast_type with
689-
| UserType name | Struct name ->
689+
| UserType name ->
690+
(* Check if this is a type alias or struct by looking up the symbol *)
691+
(match Symbol_table.lookup_symbol symbol_table name with
692+
| Some symbol ->
693+
(match symbol.kind with
694+
| Symbol_table.TypeDef (Ast.TypeAlias (_, underlying_type)) ->
695+
(* Create IRTypeAlias to preserve the alias name *)
696+
IRTypeAlias (name, ast_type_to_ir_type underlying_type)
697+
| Symbol_table.TypeDef (Ast.StructDef (_, fields, kernel_defined)) ->
698+
(* Resolve struct fields properly with type aliases preserved *)
699+
let ir_fields = List.map (fun (field_name, field_type) ->
700+
(field_name, ast_type_to_ir_type_with_context symbol_table field_type)
701+
) fields in
702+
IRStruct (name, ir_fields, kernel_defined)
703+
| Symbol_table.TypeDef (Ast.EnumDef (_, values, kernel_defined)) ->
704+
let ir_values = List.map (fun (enum_name, opt_value) ->
705+
(enum_name, Option.value ~default:0 opt_value)
706+
) values in
707+
IREnum (name, ir_values, kernel_defined)
708+
| _ -> ast_type_to_ir_type ast_type)
709+
| None ->
710+
(* Fallback to regular conversion *)
711+
ast_type_to_ir_type ast_type)
712+
| Struct name ->
690713
(* Check if this is a type alias or struct by looking up the symbol *)
691714
(match Symbol_table.lookup_symbol symbol_table name with
692715
| Some symbol ->
@@ -717,6 +740,18 @@ let rec ast_type_to_ir_type_with_context symbol_table ast_type =
717740
(* Recursively handle array element types with context *)
718741
let bounds = make_bounds_info ~min_size:size ~max_size:size () in
719742
IRArray (ast_type_to_ir_type_with_context symbol_table elem_type, size, bounds)
743+
| Enum name ->
744+
(* Check if this enum is defined in the symbol table *)
745+
(match Symbol_table.lookup_symbol symbol_table name with
746+
| Some symbol ->
747+
(match symbol.kind with
748+
| Symbol_table.TypeDef (Ast.EnumDef (_, values, kernel_defined)) ->
749+
let ir_values = List.map (fun (enum_name, opt_value) ->
750+
(enum_name, Option.value ~default:0 opt_value)
751+
) values in
752+
IREnum (name, ir_values, kernel_defined)
753+
| _ -> ast_type_to_ir_type ast_type)
754+
| None -> ast_type_to_ir_type ast_type)
720755
| _ -> ast_type_to_ir_type ast_type
721756

722757
let ast_map_type_to_ir_map_type = function

src/ir_generator.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -990,9 +990,9 @@ and resolve_type_alias ctx reg ast_type =
990990
Hashtbl.replace ctx.register_aliases reg (alias_name, underlying_ir_type);
991991
(* Create IRTypeAlias to preserve the alias name *)
992992
IRTypeAlias (alias_name, underlying_ir_type)
993-
| _ -> ast_type_to_ir_type ast_type)
994-
| None -> ast_type_to_ir_type ast_type)
995-
| _ -> ast_type_to_ir_type ast_type
993+
| _ -> ast_type_to_ir_type_with_context ctx.symbol_table ast_type)
994+
| None -> ast_type_to_ir_type_with_context ctx.symbol_table ast_type)
995+
| _ -> ast_type_to_ir_type_with_context ctx.symbol_table ast_type
996996

997997
(** Helper function to calculate stack usage for a type *)
998998
and calculate_stack_usage = function

tests/test_type_checker.ml

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,75 @@ let test_tail_call_cross_program_type_restriction _ =
12211221
true (contains_substr msg "incompatible program type")
12221222
| _ -> failwith "Expected TypeError for cross-program-type tail call")
12231223

1224+
(** Test map index type resolution bug fix - structs, enums, and type aliases as map keys *)
1225+
let test_map_index_type_resolution_bug_fix _ =
1226+
let source_code = {|
1227+
// Type alias
1228+
type IpAddress = u32
1229+
type Counter = u64
1230+
1231+
// Enum type
1232+
enum Protocol {
1233+
TCP = 6,
1234+
UDP = 17,
1235+
ICMP = 1
1236+
}
1237+
1238+
// Struct type
1239+
struct PacketInfo {
1240+
src_ip: IpAddress,
1241+
dst_ip: IpAddress,
1242+
protocol: u8
1243+
}
1244+
1245+
// Maps using different key types
1246+
map<IpAddress, Counter> connection_count : HashMap(1024) // Type alias key
1247+
map<Protocol, Counter> protocol_stats : PercpuArray(32) // Enum key
1248+
map<PacketInfo, u32> packet_filter : LruHash(512) // Struct key
1249+
1250+
@helper
1251+
fn test_indexing() -> u32 {
1252+
// Create test values
1253+
var ip: IpAddress = 0xC0A80001
1254+
var proto = TCP
1255+
var info = PacketInfo { src_ip: ip, dst_ip: ip, protocol: 6 }
1256+
1257+
// These should all work without "Array index must be integer type" error
1258+
var count1 = connection_count[ip] // Type alias as key
1259+
var count2 = protocol_stats[proto] // Enum as key
1260+
var result = packet_filter[info] // Struct as key
1261+
1262+
if (count1 != none && count2 != none && result != none) {
1263+
return count1 + count2 + result
1264+
} else {
1265+
return 0
1266+
}
1267+
}
1268+
1269+
@xdp fn packet_handler(ctx: *xdp_md) -> xdp_action {
1270+
return XDP_PASS
1271+
}
1272+
|} in
1273+
1274+
try
1275+
let ast = parse_string source_code in
1276+
let symbol_table = Test_utils.Helpers.create_test_symbol_table ast in
1277+
let _typed_ast = type_check_and_annotate_ast ~symbol_table:(Some symbol_table) ast in
1278+
1279+
(* If we reach here, type checking succeeded *)
1280+
check bool "map index type resolution works for structs, enums, and type aliases" true true
1281+
with
1282+
| Type_error (msg, _) when String.contains msg 'A' && String.contains msg 'r' && String.contains msg 'i' ->
1283+
(* If we get "Array index must be integer type" error, the test fails *)
1284+
fail ("Bug regression - map indexing should work with user types: " ^ msg)
1285+
| Type_error (msg, _) ->
1286+
(* Other type errors might be valid (e.g., map key type mismatches) *)
1287+
fail ("Unexpected type error: " ^ msg)
1288+
| Parse_error (msg, _) ->
1289+
fail ("Parse error: " ^ msg)
1290+
| e ->
1291+
fail ("Unexpected error: " ^ Printexc.to_string e)
1292+
12241293
let type_checker_tests = [
12251294
"type_unification", `Quick, test_type_unification;
12261295
"basic_type_inference", `Quick, test_basic_type_inference;
@@ -1258,6 +1327,7 @@ let type_checker_tests = [
12581327
"kernel_to_kernel_function_calls", `Quick, test_kernel_to_kernel_function_calls;
12591328
"function_call_user_type_resolution", `Quick, test_function_call_user_type_resolution;
12601329
"tail_call_cross_program_type_restriction", `Quick, test_tail_call_cross_program_type_restriction;
1330+
"map_index_type_resolution_bug_fix", `Quick, test_map_index_type_resolution_bug_fix;
12611331
]
12621332

12631333
let () =

0 commit comments

Comments
 (0)