[Mlir-commits] [mlir] Add options to generate-runtime-verification to enable faster pass running (PR #160331)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 23 08:46:13 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Hanchenng Wu (HanchengWu)
<details>
<summary>Changes</summary>
The pass generate-runtime-verification generates additional runtime op verification checks.
Currently, the pass is extremely expensive. For example, with a MobileNet v2 SSD network (converted to MLIR), running this pass alone takes 30 minutes. The same behavior has been observed on other networks as small as 5 MB.
The culprit is this line "op->print(stream, flags);" in function "RuntimeVerifiableOpInterface::generateErrorMessage" in File mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp.
As we are printing the op with all the names of the operands in the middle end, we are constructing a new SSANameState for each op->print(...) call. Thus, we are doing a new SSA analysis for each error message printed.
Perf profiling shows that 98% of the time is spent in the constructor of SSANameState.
This change adds a verbose-level option to the generate-runtime-verification pass.
verbose 0: print only source location with error message.
verbose 1: print source location and operation name and operand types with error message.
verbose 2: print the full op, including the name of the operands.
verbose 2 is the current behavior and is very expensive. I still keep the default as verbose 2.
When we switch from verbose 2 to verbose 0/1, we see the improvements below.
For MLIR imported from MobileNet v2 SSD, the running time of the pass is reduced from 32 minutes to 21 seconds.
For another small network (only 5 MB in size), the running time of the pass is reduced from 15 minutes to 4 seconds.
---
Full diff: https://github.com/llvm/llvm-project/pull/160331.diff
8 Files Affected:
- (modified) mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td (+4-2)
- (modified) mlir/include/mlir/Transforms/Passes.h (+1)
- (modified) mlir/include/mlir/Transforms/Passes.td (+8)
- (modified) mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp (+5-3)
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+25-21)
- (modified) mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp (+20-4)
- (modified) mlir/lib/Transforms/GenerateRuntimeVerification.cpp (+10-1)
- (modified) mlir/test/Dialect/Linalg/runtime-verification.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
index 6fd0df59d9d2e..e5c9336c8d8dc 100644
--- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
+++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
@@ -32,14 +32,16 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
/*retTy=*/"void",
/*methodName=*/"generateRuntimeVerification",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::Location":$loc)
+ "::mlir::Location":$loc,
+ "unsigned":$verboseLevel)
>,
];
let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
- static std::string generateErrorMessage(Operation *op, const std::string &msg);
+ static std::string generateErrorMessage(Operation *op, const std::string &msg,
+ unsigned verboseLevel = 0);
}];
}
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..58ba0892df113 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_SYMBOLPRIVATIZE
#define GEN_PASS_DECL_TOPOLOGICALSORT
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"
/// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index a39ab77fc8fb3..3d643d8a168db 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -271,8 +271,16 @@ def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
passes that are suspected to introduce faulty IR.
}];
let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+ let options = [
+ Option<"verboseLevel", "verbose-level", "unsigned", /*default=*/"2",
+ "Verbosity level for runtime verification messages: "
+ "0 = Minimum (only source location), "
+ "1 = Basic (include operation type and operand type), "
+ "2 = Detailed (include full operation details, names, types, shapes, etc.)">
+ ];
}
+
def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index b30182dc84079..608a6801af267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -32,7 +32,7 @@ struct StructuredOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
StructuredOpInterface<T>, T> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
@@ -73,7 +73,8 @@ struct StructuredOpInterface
auto msg = RuntimeVerifiableOpInterface::generateErrorMessage(
linalgOp, "unexpected negative result on dimension #" +
std::to_string(dim) + " of input/output operand #" +
- std::to_string(opOperand.getOperandNumber()));
+ std::to_string(opOperand.getOperandNumber()),
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
// Generate:
@@ -104,7 +105,8 @@ struct StructuredOpInterface
linalgOp, "dimension #" + std::to_string(dim) +
" of input/output operand #" +
std::to_string(opOperand.getOperandNumber()) +
- " is incompatible with inferred dimension size");
+ " is incompatible with inferred dimension size",
+ verboseLevel);
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index cd92026562da9..d8a7a89a3fbe7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -39,7 +39,7 @@ struct AssumeAlignmentOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto assumeOp = cast<AssumeAlignmentOp>(op);
Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
loc, assumeOp.getMemref());
@@ -53,7 +53,8 @@ struct AssumeAlignmentOpInterface
loc, isAligned,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "memref is not aligned to " +
- std::to_string(assumeOp.getAlignment())));
+ std::to_string(assumeOp.getAlignment()),
+ verboseLevel));
}
};
@@ -61,7 +62,7 @@ struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto castOp = cast<CastOp>(op);
auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
@@ -79,8 +80,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
builder.create<cf::AssertOp>(
loc, isSameRank,
- RuntimeVerifiableOpInterface::generateErrorMessage(op,
- "rank mismatch"));
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "rank mismatch", verboseLevel));
}
// Get source offset and strides. We do not have an op to get offsets and
@@ -119,7 +120,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "size mismatch of dim " + std::to_string(it.index())));
+ op, "size mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
// Get result offset and strides.
@@ -139,7 +141,7 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset mismatch"));
+ op, "offset mismatch", verboseLevel));
}
// Check strides.
@@ -157,7 +159,8 @@ struct CastOpInterface
builder.create<cf::AssertOp>(
loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "stride mismatch of dim " + std::to_string(it.index())));
+ op, "stride mismatch of dim " + std::to_string(it.index()),
+ verboseLevel));
}
}
};
@@ -166,7 +169,7 @@ struct CopyOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface,
CopyOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto copyOp = cast<CopyOp>(op);
BaseMemRefType sourceType = copyOp.getSource().getType();
BaseMemRefType targetType = copyOp.getTarget().getType();
@@ -201,7 +204,7 @@ struct CopyOpInterface
loc, sameDimSize,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size of " + std::to_string(i) +
- "-th source/target dim does not match"));
+ "-th source/target dim does not match", verboseLevel));
}
}
};
@@ -210,14 +213,14 @@ struct DimOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
DimOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto dimOp = cast<DimOp>(op);
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
builder.create<cf::AssertOp>(
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "index is out of bounds"));
+ op, "index is out of bounds", verboseLevel));
}
};
@@ -228,7 +231,7 @@ struct LoadStoreOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto loadStoreOp = cast<LoadStoreOp>(op);
auto memref = loadStoreOp.getMemref();
@@ -251,7 +254,7 @@ struct LoadStoreOpInterface
builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "out-of-bounds access"));
+ op, "out-of-bounds access", verboseLevel));
}
};
@@ -295,7 +298,7 @@ struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
@@ -323,7 +326,8 @@ struct ReinterpretCastOpInterface
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op,
- "result of reinterpret_cast is out-of-bounds of the base memref"));
+ "result of reinterpret_cast is out-of-bounds of the base memref",
+ verboseLevel));
}
};
@@ -331,7 +335,7 @@ struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto subView = cast<SubViewOp>(op);
MemRefType sourceType = subView.getSource().getType();
@@ -357,7 +361,7 @@ struct SubViewOpInterface
builder.create<cf::AssertOp>(
loc, offsetInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
- op, "offset " + std::to_string(i) + " is out-of-bounds"));
+ op, "offset " + std::to_string(i) + " is out-of-bounds", verboseLevel));
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
@@ -371,7 +375,7 @@ struct SubViewOpInterface
loc, lastPosInBounds,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview runs out-of-bounds along dimension " +
- std::to_string(i)));
+ std::to_string(i), verboseLevel));
}
}
};
@@ -380,7 +384,7 @@ struct ExpandShapeOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
ExpandShapeOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
+ Location loc, unsigned verboseLevel) const {
auto expandShapeOp = cast<ExpandShapeOp>(op);
// Verify that the expanded dim sizes are a product of the collapsed dim
@@ -414,7 +418,7 @@ struct ExpandShapeOpInterface
loc, isModZero,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
- "divide src dim evenly"));
+ "divide src dim evenly", verboseLevel));
}
}
};
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
index 8aa194befb420..8b54ed1dc3780 100644
--- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
+++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
@@ -15,7 +15,7 @@ class OpBuilder;
/// Generate an error message string for the given op and the specified error.
std::string
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
- const std::string &msg) {
+ const std::string &msg, unsigned verboseLevel) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
@@ -26,9 +26,25 @@ RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
flags.skipRegions();
flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
- op->print(stream, flags);
- stream << "\n^ " << msg;
- stream << "\nLocation: ";
+ if (verboseLevel == 2){
+ // print full op including operand names, very expensive
+ op->print(stream, flags);
+ stream << "\n " << msg;
+ }else if (verboseLevel == 1){
+ // print op name and operand types
+ stream << "Op: " << op->getName().getStringRef() << "\n";
+ stream << "Operand Types:";
+ for (const auto &operand : op->getOpOperands()) {
+ stream << " " << operand.get().getType();
+ }
+ stream << "\n" << msg;
+ stream << "Result Types:";
+ for (const auto &result : op->getResults()) {
+ stream << " " << result.getType();
+ }
+ stream << "\n" << msg;
+ }
+ stream << "^\nLocation: ";
op->getLoc().print(stream);
return buffer;
}
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..7a54ce667c6ad 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,6 +28,14 @@ struct GenerateRuntimeVerificationPass
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
+ // Check verboseLevel is in range [0, 2].
+ if (verboseLevel > 2) {
+ getOperation()->emitError(
+ "generate-runtime-verification pass: set verboseLevel to 0, 1 or 2");
+ signalPassFailure();
+ return;
+ }
+
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
// performs verification, so gather all runtime-verifiable ops first.
@@ -39,7 +47,8 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
- verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+ verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
+ verboseLevel);
};
}
diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir
index a4f29d8457e58..238169adf496e 100644
--- a/mlir/test/Dialect/Linalg/runtime-verification.mlir
+++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir
@@ -1,13 +1,25 @@
// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s
+// RUN: mlir-opt %s --generate-runtime-verification="verbose-level=1" | FileCheck %s --check-prefix=VERBOSE1
+// RUN: mlir-opt %s --generate-runtime-verification="verbose-level=0" | FileCheck %s --check-prefix=VERBOSE0
// Most of the tests for linalg runtime-verification are implemented as integration tests.
#identity = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @static_dims
+// VERBOSE1-LABEL: @static_dims
+// VERBOSE0-LABEL: @static_dims
func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) {
// CHECK: %[[TRUE:.*]] = index.bool.constant true
// CHECK: cf.assert %[[TRUE]]
+ // VERBOSE1: %[[TRUE:.*]] = index.bool.constant true
+ // VERBOSE1: cf.assert %[[TRUE]]
+ // VERBOSE1: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE1: Result Types
+ // VERBOSE1: Location: loc
+ // VERBOSE0-NOT: Operand Types: tensor<5xf32> tensor<5xf32> tensor<5xf32>
+ // VERBOSE0-NOT: Result Types
+ // VERBOSE0: Location: loc
%result = tensor.empty() : tensor<5xf32>
%0 = linalg.generic {
indexing_maps = [#identity, #identity, #identity],
@@ -26,9 +38,11 @@ func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5x
#map = affine_map<() -> ()>
// CHECK-LABEL: @scalars
+// VERBOSE1-LABEL: @scalars
func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// No runtime checks are required if the operands are all scalars
// CHECK-NOT: cf.assert
+ // VERBOSE1-NOT: cf.assert
%result = tensor.empty() : tensor<f32>
%0 = linalg.generic {
indexing_maps = [#map, #map, #map],
``````````
</details>
https://github.com/llvm/llvm-project/pull/160331
More information about the Mlir-commits
mailing list