Skip to content

Commit 6b7b0ab

Browse files
authored
Enable pass instrumentation to signal failures. (#163126)
Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere. It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes. Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .
1 parent 0f2f9e1 commit 6b7b0ab

File tree

4 files changed

+132
-13
lines changed

4 files changed

+132
-13
lines changed

mlir/include/mlir/Pass/Pass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <optional>
1818

1919
namespace mlir {
20+
class PassInstrumentation;
2021
namespace detail {
2122
class OpToOpPassAdaptor;
2223
struct OpPassManagerImpl;
@@ -341,6 +342,9 @@ class Pass {
341342

342343
/// Allow access to 'passOptions'.
343344
friend class PassInfo;
345+
346+
/// Allow access to 'signalPassFailure'.
347+
friend class PassInstrumentation;
344348
};
345349

346350
//===----------------------------------------------------------------------===//

mlir/include/mlir/Pass/PassInstrumentation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class PassInstrumentation {
8080
/// name of the analysis that was computed, its TypeID, as well as the
8181
/// current operation being analyzed.
8282
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
83+
84+
/// Helper method to enable analysis to signal pass failure. Used, for
85+
/// example, when pre- or post-conditions fail.
86+
void signalPassFailure(Pass *pass);
8387
};
8488

8589
/// This class holds a collection of PassInstrumentation objects, and invokes

mlir/lib/Pass/Pass.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -599,17 +599,21 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
599599
if (pi)
600600
pi->runBeforePass(pass, op);
601601

602-
bool passFailed = false;
603-
op->getContext()->executeAction<PassExecutionAction>(
604-
[&]() {
605-
// Invoke the virtual runOnOperation method.
606-
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
607-
adaptor->runOnOperation(verifyPasses);
608-
else
609-
pass->runOnOperation();
610-
passFailed = pass->passState->irAndPassFailed.getInt();
611-
},
612-
{op}, *pass);
602+
// Pass instrumentation can use pass failure to flag unmet invariants
603+
// (preconditions) of the pass. Skip running pass if in failure state.
604+
bool passFailed = pass->passState->irAndPassFailed.getInt();
605+
if (!passFailed) {
606+
op->getContext()->executeAction<PassExecutionAction>(
607+
[&]() {
608+
// Invoke the virtual runOnOperation method.
609+
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
610+
adaptor->runOnOperation(verifyPasses);
611+
else
612+
pass->runOnOperation();
613+
passFailed = pass->passState->irAndPassFailed.getInt();
614+
},
615+
{op}, *pass);
616+
}
613617

614618
// Invalidate any non preserved analyses.
615619
am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +644,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
640644

641645
// Instrument after the pass has run.
642646
if (pi) {
643-
if (passFailed)
647+
if (passFailed) {
644648
pi->runAfterPassFailed(pass, op);
645-
else
649+
} else {
646650
pi->runAfterPass(pass, op);
651+
passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
652+
}
647653
}
648654

649655
// Return if the pass signaled a failure.
@@ -1198,6 +1204,10 @@ void PassInstrumentation::runBeforePipeline(
11981204
void PassInstrumentation::runAfterPipeline(
11991205
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
12001206

1207+
void PassInstrumentation::signalPassFailure(Pass *pass) {
1208+
pass->signalPassFailure();
1209+
}
1210+
12011211
//===----------------------------------------------------------------------===//
12021212
// PassInstrumentor
12031213
//===----------------------------------------------------------------------===//

mlir/unittests/Pass/PassManagerTest.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/BuiltinOps.h"
1515
#include "mlir/IR/Diagnostics.h"
1616
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Pass/PassInstrumentation.h"
1718
#include "gtest/gtest.h"
1819

1920
#include <memory>
@@ -117,6 +118,106 @@ struct AddSecondAttrFunctionPass
117118
}
118119
};
119120

121+
/// PassInstrumentation to count pass callbacks and signal pass failures.
122+
struct TestPassInstrumentation : public PassInstrumentation {
123+
int beforePassCallbackCount = 0;
124+
int afterPassCallbackCount = 0;
125+
int afterPassFailedCallbackCount = 0;
126+
127+
bool failBeforePass = false;
128+
bool failAfterPass = false;
129+
130+
void runBeforePass(Pass *pass, Operation *op) override {
131+
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
132+
return;
133+
134+
++beforePassCallbackCount;
135+
if (failBeforePass)
136+
signalPassFailure(pass);
137+
}
138+
void runAfterPass(Pass *pass, Operation *op) override {
139+
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
140+
return;
141+
142+
++afterPassCallbackCount;
143+
if (failAfterPass)
144+
signalPassFailure(pass);
145+
}
146+
void runAfterPassFailed(Pass *pass, Operation *op) override {
147+
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
148+
return;
149+
150+
++afterPassFailedCallbackCount;
151+
}
152+
};
153+
154+
TEST(PassManagerTest, PassInstrumentation) {
155+
MLIRContext context;
156+
context.loadDialect<func::FuncDialect>();
157+
Builder b(&context);
158+
159+
// Create a module with 1 function.
160+
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
161+
auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
162+
b.getFunctionType({}, {}));
163+
func.setPrivate();
164+
module->push_back(func);
165+
166+
struct InstrumentationCounts {
167+
int beforePass;
168+
int afterPass;
169+
int afterPassFailed;
170+
};
171+
172+
auto runInstrumentation =
173+
[&](bool failBefore,
174+
bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
175+
// Instantiate and run our pass.
176+
auto pm = PassManager::on<ModuleOp>(&context);
177+
auto instrumentation = std::make_unique<TestPassInstrumentation>();
178+
auto *instrumentationPtr = instrumentation.get();
179+
instrumentation->failBeforePass = failBefore;
180+
instrumentation->failAfterPass = failAfter;
181+
pm.addInstrumentation(std::move(instrumentation));
182+
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
183+
LogicalResult result = pm.run(module.get());
184+
185+
InstrumentationCounts counts = {
186+
instrumentationPtr->beforePassCallbackCount,
187+
instrumentationPtr->afterPassCallbackCount,
188+
instrumentationPtr->afterPassFailedCallbackCount};
189+
return {result, counts};
190+
};
191+
192+
for (bool failBefore : {false, true}) {
193+
for (bool failAfter : {false, true}) {
194+
auto [result, counts] = runInstrumentation(failBefore, failAfter);
195+
196+
InstrumentationCounts expected;
197+
if (failBefore) {
198+
EXPECT_TRUE(failed(result))
199+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
200+
expected = {/*beforePass=*/1, /*afterPass=*/0, /*afterPassFailed=*/1};
201+
} else if (failAfter) {
202+
EXPECT_TRUE(failed(result))
203+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
204+
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
205+
} else {
206+
EXPECT_TRUE(succeeded(result))
207+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
208+
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
209+
}
210+
211+
EXPECT_EQ(counts.beforePass, expected.beforePass)
212+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
213+
EXPECT_EQ(counts.afterPass, expected.afterPass)
214+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
215+
EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
216+
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
217+
}
218+
}
219+
}
220+
120221
TEST(PassManagerTest, ExecutionAction) {
121222
MLIRContext context;
122223
context.loadDialect<func::FuncDialect>();

0 commit comments

Comments
 (0)