Skip to content

Commit 785025f

Browse files
committed
Rust: Type inference for raw pointers
1 parent 3e7a7d5 commit 785025f

File tree

7 files changed

+184
-115
lines changed

7 files changed

+184
-115
lines changed

rust/ql/lib/codeql/rust/frameworks/stdlib/Builtins.qll

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,18 @@ class RefMutType extends BuiltinType {
176176
override string getDisplayName() { result = "&mut" }
177177
}
178178

179-
/** The builtin pointer type `*const T`. */
180-
class PtrType extends BuiltinType {
181-
PtrType() { this.getName() = "Ptr" }
179+
/** A builtin raw pointer type `*const T` or `*mut T`. */
180+
abstract class PtrType extends BuiltinType { }
181+
182+
/** The builtin raw pointer type `*const T`. */
183+
class PtrConstType extends PtrType {
184+
PtrConstType() { this.getName() = "PtrConst" }
182185

183186
override string getDisplayName() { result = "*const" }
184187
}
185188

186-
/** The builtin pointer type `*mut T`. */
187-
class PtrMutType extends BuiltinType {
189+
/** The builtin raw pointer type `*mut T`. */
190+
class PtrMutType extends PtrType {
188191
PtrMutType() { this.getName() = "PtrMut" }
189192

190193
override string getDisplayName() { result = "*mut" }

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,27 @@ class NeverType extends Type, TNeverType {
339339
override Location getLocation() { result instanceof EmptyLocation }
340340
}
341341

342-
class PtrType extends StructType {
343-
PtrType() { this.getStruct() instanceof Builtins::PtrType }
342+
abstract class PtrType extends StructType {
343+
override Location getLocation() { result instanceof EmptyLocation }
344+
}
345+
346+
pragma[nomagic]
347+
TypeParamTypeParameter getPtrTypeParameter() {
348+
result = any(PtrType t).getPositionalTypeParameter(0)
349+
}
350+
351+
class PtrMutType extends PtrType {
352+
PtrMutType() { this.getStruct() instanceof Builtins::PtrMutType }
353+
354+
override string toString() { result = "*mut" }
344355

345-
override string toString() { result = "*" }
356+
override Location getLocation() { result instanceof EmptyLocation }
357+
}
358+
359+
class PtrConstType extends PtrType {
360+
PtrConstType() { this.getStruct() instanceof Builtins::PtrConstType }
361+
362+
override string toString() { result = "*const" }
346363

347364
override Location getLocation() { result instanceof EmptyLocation }
348365
}
@@ -377,11 +394,6 @@ class UnknownType extends Type, TUnknownType {
377394
override Location getLocation() { result instanceof EmptyLocation }
378395
}
379396

380-
pragma[nomagic]
381-
TypeParamTypeParameter getPtrTypeParameter() {
382-
result = any(PtrType t).getPositionalTypeParameter(0)
383-
}
384-
385397
/** A type parameter. */
386398
abstract class TypeParameter extends Type {
387399
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,10 @@ module CertainTypeInference {
431431
or
432432
result = inferLiteralType(n, path, true)
433433
or
434-
result = inferRefNodeType(n) and
434+
result = inferRefPatType(n) and
435+
path.isEmpty()
436+
or
437+
result = inferRefExprType(n) and
435438
path.isEmpty()
436439
or
437440
result = inferLogicalOperationType(n, path)
@@ -606,10 +609,14 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
606609
strictcount(Expr e | bodyReturns(n1, e)) = 1
607610
)
608611
or
609-
(
610-
n1 = n2.(RefExpr).getExpr() or
611-
n1 = n2.(RefPat).getPat()
612-
) and
612+
exists(RefExpr re |
613+
n2 = re and
614+
n1 = re.getExpr() and
615+
prefix1.isEmpty() and
616+
prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0))
617+
)
618+
or
619+
n1 = n2.(RefPat).getPat() and
613620
prefix1.isEmpty() and
614621
prefix2 = TypePath::singleton(getRefTypeParameter())
615622
or
@@ -709,9 +716,7 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
709716
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
710717
* `n2`.
711718
*/
712-
private predicate typeEqualityNonSymmetric(
713-
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
714-
) {
719+
private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
715720
lubCoercion(n2, n1, prefix2) and
716721
prefix1.isEmpty()
717722
or
@@ -723,6 +728,13 @@ private predicate typeEqualityNonSymmetric(
723728
not lubCoercion(mid, n1, _) and
724729
prefix1 = prefixMid.append(suffix)
725730
)
731+
or
732+
// When `n2` is `*n1` propagate type information from a raw pointer type
733+
// parameter at `n1`. The other direction is handled in
734+
// `inferDereferencedExprPtrType`.
735+
n1 = n2.(DerefExpr).getExpr() and
736+
prefix1 = TypePath::singleton(getPtrTypeParameter()) and
737+
prefix2.isEmpty()
726738
}
727739

728740
pragma[nomagic]
@@ -735,7 +747,7 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
735747
or
736748
typeEquality(n2, prefix2, n, prefix1)
737749
or
738-
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
750+
typeEqualityAsymmetric(n2, prefix2, n, prefix1)
739751
)
740752
}
741753

@@ -2952,16 +2964,21 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
29522964
)
29532965
}
29542966

2955-
/** Gets the root type of the reference node `ref`. */
2967+
/** Gets the root type of the reference expression `ref`. */
29562968
pragma[nomagic]
2957-
private Type inferRefNodeType(AstNode ref) {
2958-
(
2959-
ref = any(IdentPat ip | ip.isRef()).getName()
2960-
or
2961-
ref instanceof RefExpr
2969+
private Type inferRefExprType(RefExpr ref) {
2970+
if ref.isRaw()
2971+
then
2972+
ref.isMut() and result instanceof PtrMutType
29622973
or
2963-
ref instanceof RefPat
2964-
) and
2974+
ref.isConst() and result instanceof PtrConstType
2975+
else result instanceof RefType
2976+
}
2977+
2978+
/** Gets the root type of the reference node `ref`. */
2979+
pragma[nomagic]
2980+
private Type inferRefPatType(AstNode ref) {
2981+
(ref = any(IdentPat ip | ip.isRef()).getName() or ref instanceof RefPat) and
29652982
result instanceof RefType
29662983
}
29672984

@@ -3145,6 +3162,21 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
31453162
)
31463163
}
31473164

3165+
/**
3166+
* Gets the inferred type of `n` at `path` when `n` occurs in a dereference
3167+
* expression `*n` and when `n` is known to have a raw pointer type.
3168+
*
3169+
* The other direction is handled in `typeEqualityAsymmetric`.
3170+
*/
3171+
private Type inferDereferencedExprPtrType(AstNode n, TypePath path) {
3172+
exists(DerefExpr de, PtrType type, TypePath suffix |
3173+
de.getExpr() = n and
3174+
type = inferType(de.getExpr()) and
3175+
result = inferType(de, suffix) and
3176+
path = TypePath::cons(type.getPositionalTypeParameter(0), suffix)
3177+
)
3178+
}
3179+
31483180
/**
31493181
* A matching configuration for resolving types of struct patterns
31503182
* like `let Foo { bar } = ...`.
@@ -3544,6 +3576,8 @@ private module Cached {
35443576
or
35453577
result = inferIndexExprType(n, path)
35463578
or
3579+
result = inferDereferencedExprPtrType(n, path)
3580+
or
35473581
result = inferForLoopExprType(n, path)
35483582
or
35493583
result = inferDynamicCallExprType(n, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,13 +556,18 @@ class NeverTypeReprMention extends TypeMention, NeverTypeRepr {
556556
}
557557

558558
class PtrTypeReprMention extends TypeMention instanceof PtrTypeRepr {
559+
private PtrType resolveRootType() {
560+
super.isConst() and result instanceof PtrConstType
561+
or
562+
super.isMut() and result instanceof PtrMutType
563+
}
564+
559565
override Type resolveTypeAt(TypePath path) {
560-
path.isEmpty() and
561-
result instanceof PtrType
566+
path.isEmpty() and result = this.resolveRootType()
562567
or
563568
exists(TypePath suffix |
564569
result = super.getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
565-
path = TypePath::cons(getPtrTypeParameter(), suffix)
570+
path = TypePath::cons(this.resolveRootType().getPositionalTypeParameter(0), suffix)
566571
)
567572
}
568573
}

rust/ql/test/library-tests/type-inference/raw_pointer.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,47 @@
11
use std::ptr::null_mut;
22

33
fn raw_pointer_const_deref(x: *const i32) -> i32 {
4-
let _y = unsafe { *x }; // $ MISSING: type=_y:i32
4+
let _y = unsafe { *x }; // $ type=_y:i32
55
0
66
}
77

88
fn raw_pointer_mut_deref(x: *mut bool) -> i32 {
9-
let _y = unsafe { *x }; // $ MISSING: type=_y:bool
9+
let _y = unsafe { *x }; // $ type=_y:bool
1010
0
1111
}
1212

1313
fn raw_const_borrow() {
1414
let a: i64 = 10;
15-
let x = &raw const a; // $ MISSING: type=x:TPtrConst.i64
15+
let x = &raw const a; // $ type=x:TPtrConst.i64
1616
unsafe {
17-
let _y = *x; // $ type=_y:i64 SPURIOUS: target=deref
17+
let _y = *x; // $ type=_y:i64
1818
}
1919
}
2020

2121
fn raw_mut_borrow() {
2222
let mut a = 10i32;
23-
let x = &raw mut a; // $ MISSING: type=x:TPtrMut.i32
23+
let x = &raw mut a; // $ type=x:TPtrMut.i32
2424
unsafe {
25-
let _y = *x; // $ type=_y:i32 SPURIOUS: target=deref
25+
let _y = *x; // $ type=_y:i32
2626
}
2727
}
2828

2929
fn raw_mut_write(cond: bool) {
3030
let a = 10i32;
3131
// The type of `x` must be inferred from the write below.
32-
let ptr_written = null_mut(); // $ target=null_mut MISSING: type=ptr_written:TPtrMut.i32
32+
let ptr_written = null_mut(); // $ target=null_mut type=ptr_written:TPtrMut.i32
3333
if cond {
3434
unsafe {
3535
// NOTE: This write is undefined behavior because `x` is a null pointer.
3636
*ptr_written = a;
37-
let _y = *ptr_written; // $ MISSING: type=_y:i32
37+
let _y = *ptr_written; // $ type=_y:i32
3838
}
3939
}
4040
}
4141

4242
fn raw_type_from_deref(cond: bool) {
4343
// The type of `x` must be inferred from the read below.
44-
let ptr_read = null_mut(); // $ target=null_mut MISSING: type=ptr_read:TPtrMut.i64
44+
let ptr_read = null_mut(); // $ target=null_mut type=ptr_read:TPtrMut.i64
4545
if cond {
4646
unsafe {
4747
// NOTE: This read is undefined behavior because `x` is a null pointer.

0 commit comments

Comments
 (0)