Skip to content

Commit 1a19df2

Browse files
authored
Merge pull request #20950 from paldepind/rust/ti-raw-pointer
Rust: Type inference for raw pointers
2 parents 6d301f2 + 27ddc81 commit 1a19df2

File tree

14 files changed

+4765
-72
lines changed

14 files changed

+4765
-72
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
category: minorAnalysis
3+
---
4+
* Improved type inference for raw pointers (`*const` and `*mut`). This includes type inference for the raw borrow operators (`&raw const` and `&raw mut`) and dereferencing of raw pointers.

rust/ql/lib/codeql/rust/frameworks/stdlib/Builtins.qll

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,20 @@ class RefMutType extends BuiltinType {
176176
override string getDisplayName() { result = "&mut" }
177177
}
178178

179-
/** The builtin pointer type `*const T`. */
180-
class PtrType extends BuiltinType {
181-
PtrType() { this.getName() = "Ptr" }
179+
/** A builtin raw pointer type `*const T` or `*mut T`. */
180+
abstract private class PtrTypeImpl extends BuiltinType { }
181+
182+
final class PtrType = PtrTypeImpl;
183+
184+
/** The builtin raw pointer type `*const T`. */
185+
class PtrConstType extends PtrTypeImpl {
186+
PtrConstType() { this.getName() = "PtrConst" }
182187

183188
override string getDisplayName() { result = "*const" }
184189
}
185190

186-
/** The builtin pointer type `*mut T`. */
187-
class PtrMutType extends BuiltinType {
191+
/** The builtin raw pointer type `*mut T`. */
192+
class PtrMutType extends PtrTypeImpl {
188193
PtrMutType() { this.getName() = "PtrMut" }
189194

190195
override string getDisplayName() { result = "*mut" }

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,8 +774,11 @@ private TypeItemNode resolveBuiltin(TypeRepr tr) {
774774
tr instanceof RefTypeRepr and
775775
result instanceof Builtins::RefType
776776
or
777-
tr instanceof PtrTypeRepr and
778-
result instanceof Builtins::PtrType
777+
tr.(PtrTypeRepr).isConst() and
778+
result instanceof Builtins::PtrConstType
779+
or
780+
tr.(PtrTypeRepr).isMut() and
781+
result instanceof Builtins::PtrMutType
779782
or
780783
result.(Builtins::TupleType).getArity() = tr.(TupleTypeRepr).getNumberOfFields()
781784
}

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,23 @@ class NeverType extends Type, TNeverType {
339339
override Location getLocation() { result instanceof EmptyLocation }
340340
}
341341

342-
class PtrType extends StructType {
343-
PtrType() { this.getStruct() instanceof Builtins::PtrType }
342+
abstract class PtrType extends StructType { }
344343

345-
override string toString() { result = "*" }
344+
pragma[nomagic]
345+
TypeParamTypeParameter getPtrTypeParameter() {
346+
result = any(PtrType t).getPositionalTypeParameter(0)
347+
}
346348

347-
override Location getLocation() { result instanceof EmptyLocation }
349+
class PtrMutType extends PtrType {
350+
PtrMutType() { this.getStruct() instanceof Builtins::PtrMutType }
351+
352+
override string toString() { result = "*mut" }
353+
}
354+
355+
class PtrConstType extends PtrType {
356+
PtrConstType() { this.getStruct() instanceof Builtins::PtrConstType }
357+
358+
override string toString() { result = "*const" }
348359
}
349360

350361
/**
@@ -377,11 +388,6 @@ class UnknownType extends Type, TUnknownType {
377388
override Location getLocation() { result instanceof EmptyLocation }
378389
}
379390

380-
pragma[nomagic]
381-
TypeParamTypeParameter getPtrTypeParameter() {
382-
result = any(PtrType t).getPositionalTypeParameter(0)
383-
}
384-
385391
/** A type parameter. */
386392
abstract class TypeParameter extends Type {
387393
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,10 @@ module CertainTypeInference {
424424
or
425425
result = inferLiteralType(n, path, true)
426426
or
427-
result = inferRefNodeType(n) and
427+
result = inferRefPatType(n) and
428+
path.isEmpty()
429+
or
430+
result = inferRefExprType(n) and
428431
path.isEmpty()
429432
or
430433
result = inferLogicalOperationType(n, path)
@@ -599,10 +602,14 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
599602
strictcount(Expr e | bodyReturns(n1, e)) = 1
600603
)
601604
or
602-
(
603-
n1 = n2.(RefExpr).getExpr() or
604-
n1 = n2.(RefPat).getPat()
605-
) and
605+
n2 =
606+
any(RefExpr re |
607+
n1 = re.getExpr() and
608+
prefix1.isEmpty() and
609+
prefix2 = TypePath::singleton(inferRefExprType(re).getPositionalTypeParameter(0))
610+
)
611+
or
612+
n1 = n2.(RefPat).getPat() and
606613
prefix1.isEmpty() and
607614
prefix2 = TypePath::singleton(getRefTypeParameter())
608615
or
@@ -702,9 +709,7 @@ private predicate lubCoercion(AstNode parent, AstNode child, TypePath prefix) {
702709
* of `n2` at `prefix2`, but type information should only propagate from `n1` to
703710
* `n2`.
704711
*/
705-
private predicate typeEqualityNonSymmetric(
706-
AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2
707-
) {
712+
private predicate typeEqualityAsymmetric(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
708713
lubCoercion(n2, n1, prefix2) and
709714
prefix1.isEmpty()
710715
or
@@ -716,6 +721,13 @@ private predicate typeEqualityNonSymmetric(
716721
not lubCoercion(mid, n1, _) and
717722
prefix1 = prefixMid.append(suffix)
718723
)
724+
or
725+
// When `n2` is `*n1` propagate type information from a raw pointer type
726+
// parameter at `n1`. The other direction is handled in
727+
// `inferDereferencedExprPtrType`.
728+
n1 = n2.(DerefExpr).getExpr() and
729+
prefix1 = TypePath::singleton(getPtrTypeParameter()) and
730+
prefix2.isEmpty()
719731
}
720732

721733
pragma[nomagic]
@@ -728,7 +740,7 @@ private Type inferTypeEquality(AstNode n, TypePath path) {
728740
or
729741
typeEquality(n2, prefix2, n, prefix1)
730742
or
731-
typeEqualityNonSymmetric(n2, prefix2, n, prefix1)
743+
typeEqualityAsymmetric(n2, prefix2, n, prefix1)
732744
)
733745
}
734746

@@ -2999,16 +3011,21 @@ private Type inferFieldExprType(AstNode n, TypePath path) {
29993011
)
30003012
}
30013013

3002-
/** Gets the root type of the reference node `ref`. */
3014+
/** Gets the root type of the reference expression `ref`. */
30033015
pragma[nomagic]
3004-
private Type inferRefNodeType(AstNode ref) {
3005-
(
3006-
ref = any(IdentPat ip | ip.isRef()).getName()
3016+
private Type inferRefExprType(RefExpr ref) {
3017+
if ref.isRaw()
3018+
then
3019+
ref.isMut() and result instanceof PtrMutType
30073020
or
3008-
ref instanceof RefExpr
3009-
or
3010-
ref instanceof RefPat
3011-
) and
3021+
ref.isConst() and result instanceof PtrConstType
3022+
else result instanceof RefType
3023+
}
3024+
3025+
/** Gets the root type of the reference node `ref`. */
3026+
pragma[nomagic]
3027+
private Type inferRefPatType(AstNode ref) {
3028+
(ref = any(IdentPat ip | ip.isRef()).getName() or ref instanceof RefPat) and
30123029
result instanceof RefType
30133030
}
30143031

@@ -3192,6 +3209,27 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
31923209
)
31933210
}
31943211

3212+
pragma[nomagic]
3213+
private Type getInferredDerefType(DerefExpr de, TypePath path) { result = inferType(de, path) }
3214+
3215+
pragma[nomagic]
3216+
private PtrType getInferredDerefExprPtrType(DerefExpr de) { result = inferType(de.getExpr()) }
3217+
3218+
/**
3219+
* Gets the inferred type of `n` at `path` when `n` occurs in a dereference
3220+
* expression `*n` and when `n` is known to have a raw pointer type.
3221+
*
3222+
* The other direction is handled in `typeEqualityAsymmetric`.
3223+
*/
3224+
private Type inferDereferencedExprPtrType(AstNode n, TypePath path) {
3225+
exists(DerefExpr de, PtrType type, TypePath suffix |
3226+
de.getExpr() = n and
3227+
type = getInferredDerefExprPtrType(de) and
3228+
result = getInferredDerefType(de, suffix) and
3229+
path = TypePath::cons(type.getPositionalTypeParameter(0), suffix)
3230+
)
3231+
}
3232+
31953233
/**
31963234
* A matching configuration for resolving types of struct patterns
31973235
* like `let Foo { bar } = ...`.
@@ -3593,6 +3631,8 @@ private module Cached {
35933631
or
35943632
result = inferIndexExprType(n, path)
35953633
or
3634+
result = inferDereferencedExprPtrType(n, path)
3635+
or
35963636
result = inferForLoopExprType(n, path)
35973637
or
35983638
result = inferDynamicCallExprType(n, path)

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,13 +556,18 @@ class NeverTypeReprMention extends TypeMention, NeverTypeRepr {
556556
}
557557

558558
class PtrTypeReprMention extends TypeMention instanceof PtrTypeRepr {
559+
private PtrType resolveRootType() {
560+
super.isConst() and result instanceof PtrConstType
561+
or
562+
super.isMut() and result instanceof PtrMutType
563+
}
564+
559565
override Type resolveTypeAt(TypePath path) {
560-
path.isEmpty() and
561-
result instanceof PtrType
566+
path.isEmpty() and result = this.resolveRootType()
562567
or
563568
exists(TypePath suffix |
564569
result = super.getTypeRepr().(TypeMention).resolveTypeAt(suffix) and
565-
path = TypePath::cons(getPtrTypeParameter(), suffix)
570+
path = TypePath::cons(this.resolveRootType().getPositionalTypeParameter(0), suffix)
566571
)
567572
}
568573
}

rust/ql/test/library-tests/elements/builtintypes/BuiltinTypes.expected

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
| struct Array | |
2-
| struct Ptr | |
2+
| struct PtrConst | |
33
| struct PtrMut | |
44
| struct Ref | |
55
| struct RefMut | |
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
use std::ptr::null_mut;
2+
3+
fn raw_pointer_const_deref(x: *const i32) -> i32 {
4+
let _y = unsafe { *x }; // $ type=_y:i32
5+
0
6+
}
7+
8+
fn raw_pointer_mut_deref(x: *mut bool) -> i32 {
9+
let _y = unsafe { *x }; // $ type=_y:bool
10+
0
11+
}
12+
13+
fn raw_const_borrow() {
14+
let a: i64 = 10;
15+
let x = &raw const a; // $ type=x:TPtrConst.i64
16+
unsafe {
17+
let _y = *x; // $ type=_y:i64
18+
}
19+
}
20+
21+
fn raw_mut_borrow() {
22+
let mut a = 10i32;
23+
let x = &raw mut a; // $ type=x:TPtrMut.i32
24+
unsafe {
25+
let _y = *x; // $ type=_y:i32
26+
}
27+
}
28+
29+
fn raw_mut_write(cond: bool) {
30+
let a = 10i32;
31+
// The type of `ptr_written` must be inferred from the write below.
32+
let ptr_written = null_mut(); // $ target=null_mut type=ptr_written:TPtrMut.i32
33+
if cond {
34+
unsafe {
35+
// NOTE: This write is undefined behavior because `ptr_written` is a null pointer.
36+
*ptr_written = a;
37+
let _y = *ptr_written; // $ type=_y:i32
38+
}
39+
}
40+
}
41+
42+
fn raw_type_from_deref(cond: bool) {
43+
// The type of `ptr_read` must be inferred from the read below.
44+
let ptr_read = null_mut(); // $ target=null_mut type=ptr_read:TPtrMut.i64
45+
if cond {
46+
unsafe {
47+
// NOTE: This read is undefined behavior because `ptr_read` is a null pointer.
48+
let _y: i64 = *ptr_read;
49+
}
50+
}
51+
}
52+
53+
pub fn test() {
54+
raw_pointer_const_deref(&10); // $ target=raw_pointer_const_deref
55+
raw_pointer_mut_deref(&mut true); // $ target=raw_pointer_mut_deref
56+
raw_const_borrow(); // $ target=raw_const_borrow
57+
raw_mut_borrow(); // $ target=raw_mut_borrow
58+
raw_mut_write(false); // $ target=raw_mut_write
59+
raw_type_from_deref(false); // $ target=raw_type_from_deref
60+
}

0 commit comments

Comments
 (0)