@@ -424,7 +424,10 @@ module CertainTypeInference {
424424 or
425425 result = inferLiteralType ( n , path , true )
426426 or
427- result = inferRefNodeType ( n ) and
427+ result = inferRefPatType ( n ) and
428+ path .isEmpty ( )
429+ or
430+ result = inferRefExprType ( n ) and
428431 path .isEmpty ( )
429432 or
430433 result = inferLogicalOperationType ( n , path )
@@ -599,10 +602,14 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
599602 strictcount ( Expr e | bodyReturns ( n1 , e ) ) = 1
600603 )
601604 or
602- (
603- n1 = n2 .( RefExpr ) .getExpr ( ) or
604- n1 = n2 .( RefPat ) .getPat ( )
605- ) and
605+ n2 =
606+ any ( RefExpr re |
607+ n1 = re .getExpr ( ) and
608+ prefix1 .isEmpty ( ) and
609+ prefix2 = TypePath:: singleton ( inferRefExprType ( re ) .getPositionalTypeParameter ( 0 ) )
610+ )
611+ or
612+ n1 = n2 .( RefPat ) .getPat ( ) and
606613 prefix1 .isEmpty ( ) and
607614 prefix2 = TypePath:: singleton ( getRefTypeParameter ( ) )
608615 or
@@ -702,9 +709,7 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
702709 * of `n2` at `prefix2`, but type information should only propagate from `n1` to
703710 * `n2`.
704711 */
705- private predicate typeEqualityNonSymmetric (
706- AstNode n1 , TypePath prefix1 , AstNode n2 , TypePath prefix2
707- ) {
712+ private predicate typeEqualityAsymmetric ( AstNode n1 , TypePath prefix1 , AstNode n2 , TypePath prefix2 ) {
708713 lubCoercion ( n2 , n1 , prefix2 ) and
709714 prefix1 .isEmpty ( )
710715 or
@@ -716,6 +721,13 @@ private predicate typeEqualityNonSymmetric(
716721 not lubCoercion ( mid , n1 , _) and
717722 prefix1 = prefixMid .append ( suffix )
718723 )
724+ or
725+ // When `n2` is `*n1` propagate type information from a raw pointer type
726+ // parameter at `n1`. The other direction is handled in
727+ // `inferDereferencedExprPtrType`.
728+ n1 = n2 .( DerefExpr ) .getExpr ( ) and
729+ prefix1 = TypePath:: singleton ( getPtrTypeParameter ( ) ) and
730+ prefix2 .isEmpty ( )
719731}
720732
721733pragma [ nomagic]
@@ -728,7 +740,7 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
728740 or
729741 typeEquality ( n2 , prefix2 , n , prefix1 )
730742 or
731- typeEqualityNonSymmetric ( n2 , prefix2 , n , prefix1 )
743+ typeEqualityAsymmetric ( n2 , prefix2 , n , prefix1 )
732744 )
733745}
734746
@@ -2999,16 +3011,21 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
29993011 )
30003012}
30013013
3002- /** Gets the root type of the reference node `ref`. */
3014+ /** Gets the root type of the reference expression `ref`. */
30033015pragma [ nomagic]
3004- private Type inferRefNodeType ( AstNode ref ) {
3005- (
3006- ref = any ( IdentPat ip | ip .isRef ( ) ) .getName ( )
3016+ private Type inferRefExprType ( RefExpr ref ) {
3017+ if ref .isRaw ( )
3018+ then
3019+ ref .isMut ( ) and result instanceof PtrMutType
30073020 or
3008- ref instanceof RefExpr
3009- or
3010- ref instanceof RefPat
3011- ) and
3021+ ref .isConst ( ) and result instanceof PtrConstType
3022+ else result instanceof RefType
3023+ }
3024+
3025+ /** Gets the root type of the reference node `ref`. */
3026+ pragma [ nomagic]
3027+ private Type inferRefPatType ( AstNode ref ) {
3028+ ( ref = any ( IdentPat ip | ip .isRef ( ) ) .getName ( ) or ref instanceof RefPat ) and
30123029 result instanceof RefType
30133030}
30143031
@@ -3192,6 +3209,27 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
31923209 )
31933210}
31943211
3212+ pragma [ nomagic]
3213+ private Type getInferredDerefType ( DerefExpr de , TypePath path ) { result = inferType ( de , path ) }
3214+
3215+ pragma [ nomagic]
3216+ private PtrType getInferredDerefExprPtrType ( DerefExpr de ) { result = inferType ( de .getExpr ( ) ) }
3217+
3218+ /**
3219+ * Gets the inferred type of `n` at `path` when `n` occurs in a dereference
3220+ * expression `*n` and when `n` is known to have a raw pointer type.
3221+ *
3222+ * The other direction is handled in `typeEqualityAsymmetric`.
3223+ */
3224+ private Type inferDereferencedExprPtrType ( AstNode n , TypePath path ) {
3225+ exists ( DerefExpr de , PtrType type , TypePath suffix |
3226+ de .getExpr ( ) = n and
3227+ type = getInferredDerefExprPtrType ( de ) and
3228+ result = getInferredDerefType ( de , suffix ) and
3229+ path = TypePath:: cons ( type .getPositionalTypeParameter ( 0 ) , suffix )
3230+ )
3231+ }
3232+
31953233/**
31963234 * A matching configuration for resolving types of struct patterns
31973235 * like `let Foo { bar } = ...`.
@@ -3593,6 +3631,8 @@ private module Cached {
35933631 or
35943632 result = inferIndexExprType ( n , path )
35953633 or
3634+ result = inferDereferencedExprPtrType ( n , path )
3635+ or
35963636 result = inferForLoopExprType ( n , path )
35973637 or
35983638 result = inferDynamicCallExprType ( n , path )
0 commit comments