Skip to content

Commit 294c489

Browse files
committed
Rust: Handle x[y] expressions as *.index(y) calls in data flow
1 parent e72c8ac commit 294c489

File tree

17 files changed

+513
-341
lines changed

17 files changed

+513
-341
lines changed

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ final class DataFlowCall extends TDataFlowCall {
9090
* Holds if `arg` is an argument of `call` at the position `pos`.
9191
*/
9292
predicate isArgumentForCall(Expr arg, Call call, RustDataFlow::ArgumentPosition pos) {
93-
// TODO: Handle index expressions as calls in data flow.
94-
not call instanceof IndexExpr and
9593
arg = pos.getArgument(call)
9694
}
9795

@@ -261,6 +259,30 @@ private module Aliases {
261259
class LambdaCallKindAlias = LambdaCallKind;
262260
}
263261

262+
/**
263+
* Index assignments like `a[i] = rhs` are treated as `*a.index_mut(i) = rhs`,
264+
* so they should in principle be handled by `referenceAssignment`.
265+
*
266+
* However, this would require support for [generalized reverse flow][1], which
267+
* is not yet implemented, so instead we simulate reverse flow where it would
268+
* have applied via the model for `<_ as core::ops::index::IndexMut>::index_mut`.
269+
*
270+
* The same is the case for compound assignments like `a[i] += rhs`, which are
271+
* treated as `(*a.index_mut(i)).add_assign(rhs)`.
272+
*
273+
* [1]: https://github.com/github/codeql/pull/18109
274+
*/
275+
predicate indexAssignment(
276+
AssignmentOperation assignment, IndexExpr index, Node rhs, PostUpdateNode base, Content c
277+
) {
278+
assignment.getLhs() = index and
279+
rhs.asExpr() = assignment.getRhs() and
280+
base.getPreUpdateNode().asExpr() = index.getBase() and
281+
c instanceof ElementContent and
282+
// simulate that the flow summary applies
283+
not index.getResolvedTarget().fromSource()
284+
}
285+
264286
module RustDataFlow implements InputSig<Location> {
265287
private import Aliases
266288
private import codeql.rust.dataflow.DataFlow
@@ -360,6 +382,7 @@ module RustDataFlow implements InputSig<Location> {
360382
node instanceof ClosureParameterNode or
361383
node instanceof DerefBorrowNode or
362384
node instanceof DerefOutNode or
385+
node instanceof IndexOutNode or
363386
node.asExpr() instanceof ParenExpr or
364387
nodeIsHidden(node.(PostUpdateNode).getPreUpdateNode())
365388
}
@@ -552,12 +575,6 @@ module RustDataFlow implements InputSig<Location> {
552575
access = c.(FieldContent).getAnAccess()
553576
)
554577
or
555-
exists(IndexExpr arr |
556-
c instanceof ElementContent and
557-
node1.asExpr() = arr.getBase() and
558-
node2.asExpr() = arr
559-
)
560-
or
561578
exists(ForExpr for |
562579
c instanceof ElementContent and
563580
node1.asExpr() = for.getIterable() and
@@ -583,6 +600,12 @@ module RustDataFlow implements InputSig<Location> {
583600
node2.asExpr() = deref
584601
)
585602
or
603+
exists(IndexExpr index |
604+
c instanceof ReferenceContent and
605+
node1.(IndexOutNode).getIndexExpr() = index and
606+
node2.asExpr() = index
607+
)
608+
or
586609
// Read from function return
587610
exists(DataFlowCall call |
588611
lambdaCall(call, _, node1) and
@@ -644,13 +667,27 @@ module RustDataFlow implements InputSig<Location> {
644667
}
645668

646669
pragma[nomagic]
647-
private predicate referenceAssignment(Node node1, Node node2, ReferenceContent c) {
648-
exists(AssignmentExpr assignment, PrefixExpr deref |
649-
assignment.getLhs() = deref and
650-
deref.getOperatorName() = "*" and
670+
private predicate referenceAssignment(
671+
Node node1, Node node2, Expr e, boolean clears, ReferenceContent c
672+
) {
673+
exists(AssignmentExpr assignment, Expr lhs |
674+
assignment.getLhs() = lhs and
651675
node1.asExpr() = assignment.getRhs() and
652-
node2.asExpr() = deref.getExpr() and
653676
exists(c)
677+
|
678+
lhs =
679+
any(DerefExpr de |
680+
de = node2.(DerefOutNode).getDerefExpr() and
681+
e = de.getExpr()
682+
) and
683+
clears = true
684+
or
685+
lhs =
686+
any(IndexExpr ie |
687+
ie = node2.(IndexOutNode).getIndexExpr() and
688+
e = ie.getBase() and
689+
clears = false
690+
)
654691
)
655692
}
656693

@@ -694,14 +731,14 @@ module RustDataFlow implements InputSig<Location> {
694731
or
695732
fieldAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
696733
or
697-
referenceAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), c)
734+
referenceAssignment(node1, node2.(PostUpdateNode).getPreUpdateNode(), _, _, c)
698735
or
699-
exists(AssignmentExpr assignment, IndexExpr index |
700-
c instanceof ElementContent and
701-
assignment.getLhs() = index and
702-
node1.asExpr() = assignment.getRhs() and
703-
node2.(PostUpdateNode).getPreUpdateNode().asExpr() = index.getBase()
704-
)
736+
indexAssignment(any(AssignmentExpr ae), _, node1, node2, c)
737+
or
738+
// Compound assignments like `a[i] += rhs` are modeled as a store step from `rhs`
739+
// to `[post] a[i]`, followed by a taint step into `[post] a`.
740+
indexAssignment(any(CompoundAssignmentExpr cae),
741+
node2.(PostUpdateNode).getPreUpdateNode().asExpr(), node1, _, c)
705742
or
706743
referenceExprToExpr(node1, node2, c)
707744
or
@@ -738,7 +775,7 @@ module RustDataFlow implements InputSig<Location> {
738775
predicate clearsContent(Node n, ContentSet cs) {
739776
fieldAssignment(_, n, cs.(SingletonContentSet).getContent())
740777
or
741-
referenceAssignment(_, n, cs.(SingletonContentSet).getContent())
778+
referenceAssignment(_, _, n.asExpr(), true, cs.(SingletonContentSet).getContent())
742779
or
743780
FlowSummaryImpl::Private::Steps::summaryClearsContent(n.(FlowSummaryNode).getSummaryNode(), cs)
744781
or
@@ -982,9 +1019,7 @@ private module Cached {
9821019
newtype TDataFlowCall =
9831020
TCall(Call call) {
9841021
Stages::DataFlowStage::ref() and
985-
call.hasEnclosingCfgScope() and
986-
// TODO: Handle index expressions as calls in data flow.
987-
not call instanceof IndexExpr
1022+
call.hasEnclosingCfgScope()
9881023
} or
9891024
TSummaryCall(
9901025
FlowSummaryImpl::Public::SummarizedCallable c, FlowSummaryImpl::Private::SummaryNode receiver

rust/ql/lib/codeql/rust/dataflow/internal/Node.qll

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ final private class ExprOutNode extends ExprNode, OutNode {
350350
ExprOutNode() {
351351
exists(Call call |
352352
call = this.asExpr() and
353-
not call instanceof DerefExpr // Handled by `DerefOutNode`
353+
not call instanceof DerefExpr and // Handled by `DerefOutNode`
354+
not call instanceof IndexExpr // Handled by `IndexOutNode`
354355
)
355356
}
356357

@@ -387,6 +388,32 @@ class DerefOutNode extends OutNode, TDerefOutNode {
387388
override string toString() { result = de.toString() + " [pre-dereferenced]" }
388389
}
389390

391+
/**
392+
* A node that represents the value of a `x[y]` expression _before_ implicit
393+
* dereferencing:
394+
*
395+
* `x[y]` is equivalent to `*x.index(y)`, and this node represents the
396+
* `x.index(y)` part.
397+
*/
398+
class IndexOutNode extends OutNode, TIndexOutNode {
399+
IndexExpr ie;
400+
401+
IndexOutNode() { this = TIndexOutNode(ie, false) }
402+
403+
IndexExpr getIndexExpr() { result = ie }
404+
405+
override CfgScope getCfgScope() { result = ie.getEnclosingCfgScope() }
406+
407+
override DataFlowCall getCall(ReturnKind kind) {
408+
result.asCall() = ie and
409+
kind = TNormalReturnKind()
410+
}
411+
412+
override Location getLocation() { result = ie.getLocation() }
413+
414+
override string toString() { result = ie.toString() + " [pre-dereferenced]" }
415+
}
416+
390417
final class SummaryOutNode extends FlowSummaryNode, OutNode {
391418
private DataFlowCall call;
392419
private ReturnKind kind_;
@@ -476,6 +503,18 @@ class DerefOutPostUpdateNode extends PostUpdateNode, TDerefOutNode {
476503
override Location getLocation() { result = de.getLocation() }
477504
}
478505

506+
class IndexOutPostUpdateNode extends PostUpdateNode, TIndexOutNode {
507+
IndexExpr ie;
508+
509+
IndexOutPostUpdateNode() { this = TIndexOutNode(ie, true) }
510+
511+
override IndexOutNode getPreUpdateNode() { result = TIndexOutNode(ie, false) }
512+
513+
override CfgScope getCfgScope() { result = ie.getEnclosingCfgScope() }
514+
515+
override Location getLocation() { result = ie.getLocation() }
516+
}
517+
479518
final class SummaryPostUpdateNode extends FlowSummaryNode, PostUpdateNode {
480519
private FlowSummaryNode pre;
481520

@@ -514,7 +553,8 @@ newtype TNode =
514553
TExprPostUpdateNode(Expr e) {
515554
e.hasEnclosingCfgScope() and
516555
(
517-
isArgumentForCall(e, _, _)
556+
isArgumentForCall(e, _, _) and
557+
not (e = any(CompoundAssignmentExpr cae).getLhs() and e instanceof VariableAccess)
518558
or
519559
lambdaCallExpr(_, _, e)
520560
or
@@ -526,7 +566,6 @@ newtype TNode =
526566
or
527567
e =
528568
[
529-
any(IndexExpr i).getBase(), //
530569
any(FieldExpr access).getContainer(), //
531570
any(TryExpr try).getExpr(), //
532571
any(AwaitExpr a).getExpr(), //
@@ -542,6 +581,7 @@ newtype TNode =
542581
borrow = true
543582
} or
544583
TDerefOutNode(DerefExpr de, Boolean isPost) or
584+
TIndexOutNode(IndexExpr ie, Boolean isPost) or
545585
TSsaNode(SsaImpl::DataFlowIntegration::SsaNode node) or
546586
TFlowSummaryNode(FlowSummaryImpl::Private::SummaryNode sn) {
547587
forall(AstNode n | n = sn.getSinkElement() or n = sn.getSourceElement() |

rust/ql/lib/codeql/rust/dataflow/internal/TaintTrackingImpl.qll

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ module RustTaintTracking implements InputSig<Location, RustDataFlow> {
6565
or
6666
succ.(Node::PostUpdateNode).getPreUpdateNode().asExpr() =
6767
getPostUpdateReverseStep(pred.(Node::PostUpdateNode).getPreUpdateNode().asExpr(), false)
68+
or
69+
indexAssignment(any(CompoundAssignmentExpr cae),
70+
pred.(Node::PostUpdateNode).getPreUpdateNode().asExpr(), _, succ, _)
6871
)
6972
or
7073
FlowSummaryImpl::Private::Steps::summaryLocalStep(pred.(Node::FlowSummaryNode).getSummaryNode(),

rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -676,12 +676,12 @@ module Impl {
676676
predicate isCapture() { this.getEnclosingCfgScope() != v.getEnclosingCfgScope() }
677677
}
678678

679-
/** Holds if `e` occurs in the LHS of an assignment or compound assignment. */
680-
private predicate assignmentExprDescendant(AssignmentExpr ae, Expr e) {
681-
e = ae.getLhs()
679+
/** Holds if `e` occurs in the LHS of an assignment operation. */
680+
predicate assignmentOperationDescendant(AssignmentOperation ao, Expr e) {
681+
e = ao.getLhs()
682682
or
683683
exists(Expr mid |
684-
assignmentExprDescendant(ae, mid) and
684+
assignmentOperationDescendant(ao, mid) and
685685
getImmediateParentAdj(e) = mid and
686686
not mid instanceof DerefExpr and
687687
not mid instanceof FieldExpr and
@@ -696,7 +696,7 @@ module Impl {
696696
cached
697697
VariableWriteAccess() {
698698
Cached::ref() and
699-
assignmentExprDescendant(ae, this)
699+
assignmentOperationDescendant(ae, this)
700700
}
701701

702702
/** Gets the assignment expression that has this write access in the left-hand side. */

rust/ql/lib/codeql/rust/frameworks/stdlib/core.model.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ extensions:
66
# Builtin deref
77
- ["<& as core::ops::deref::Deref>::deref", "Argument[self].Reference", "ReturnValue", "value", "manual"]
88
- ["<&mut as core::ops::deref::Deref>::deref", "Argument[self].Reference", "ReturnValue", "value", "manual"]
9+
# Index
10+
- ["<_ as core::ops::index::Index>::index", "Argument[self].Reference.Element", "ReturnValue.Reference", "value", "manual"]
11+
- ["<_ as core::ops::index::IndexMut>::index_mut", "Argument[self].Reference.Element", "ReturnValue.Reference", "value", "manual"]
912
# Arithmetic
1013
- ["<_ as core::ops::arith::Add>::add", "Argument[self]", "ReturnValue", "taint", "manual"]
1114
- ["<_ as core::ops::arith::Add>::add", "Argument[0]", "ReturnValue", "taint", "manual"]

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ private import TypeMention
99
private import typeinference.FunctionType
1010
private import typeinference.FunctionOverloading as FunctionOverloading
1111
private import typeinference.BlanketImplementation as BlanketImplementation
12+
private import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl
1213
private import codeql.rust.internal.CachedStages
1314
private import codeql.typeinference.internal.TypeInference
1415
private import codeql.rust.frameworks.stdlib.Stdlib
@@ -672,7 +673,7 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
672673

673674
/**
674675
* Holds if `child` is a child of `parent`, and the Rust compiler applies [least
675-
* upper bound (LUB) coercion](1) to infer the type of `parent` from the type of
676+
* upper bound (LUB) coercion][1] to infer the type of `parent` from the type of
676677
* `child`.
677678
*
678679
* In this case, we want type information to only flow from `child` to `parent`,
@@ -1645,9 +1646,14 @@ private module MethodResolution {
16451646
}
16461647

16471648
private class MethodCallIndexExpr extends MethodCall instanceof IndexExpr {
1649+
private predicate isInMutableContext() {
1650+
// todo: does not handle all cases yet
1651+
VariableImpl::assignmentOperationDescendant(_, this)
1652+
}
1653+
16481654
pragma[nomagic]
16491655
override predicate hasNameAndArity(string name, int arity) {
1650-
name = "index" and
1656+
(if this.isInMutableContext() then name = "index_mut" else name = "index") and
16511657
arity = 1
16521658
}
16531659

@@ -1661,7 +1667,11 @@ private module MethodResolution {
16611667

16621668
override predicate supportsAutoDerefAndBorrow() { any() }
16631669

1664-
override Trait getTrait() { result.getCanonicalPath() = "core::ops::index::Index" }
1670+
override Trait getTrait() {
1671+
if this.isInMutableContext()
1672+
then result.getCanonicalPath() = "core::ops::index::IndexMut"
1673+
else result.getCanonicalPath() = "core::ops::index::Index"
1674+
}
16651675
}
16661676

16671677
private class MethodCallCallExpr extends MethodCall instanceof CallExpr {

0 commit comments

Comments
 (0)