[Mlir-commits] [mlir] 79f4143 - [mlir][Transforms] Dialect conversion: Move `hasRewrite` to expensive checks (#119848)

llvmlistbot at llvm.org
Fri Dec 13 02:26:16 PST 2024


Author: Matthias Springer
Date: 2024-12-13T11:26:12+01:00
New Revision: 79f41434460d3305c889a6483ea59f1e3ea19b5a

URL: https://github.com/llvm/llvm-project/commit/79f41434460d3305c889a6483ea59f1e3ea19b5a
DIFF: https://github.com/llvm/llvm-project/commit/79f41434460d3305c889a6483ea59f1e3ea19b5a.diff

LOG: [mlir][Transforms] Dialect conversion: Move `hasRewrite` to expensive checks (#119848)

The dialect conversion has various checks that detect incorrect API
usage in patterns. One of these checks, which verifies that a block is
not converted multiple times, turned out to be quite expensive in
NVIDIA-internal workloads: its complexity is O(N*M), where N is the
number of block rewrites and M is the total number of rewrites.

This check iterates over the stack of all rewrites, which can be large.
We saw `hasRewrite` being called around 45000 times with an average
rewrite stack size of 500000.
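
To make the N*M cost concrete, here is a minimal C++ sketch built on a heavily simplified model of the rewrite stack. `Rewrite`, `BlockTypeConversionRewrite`, `OtherRewrite`, `RewriteStack`, `hasBlockConversion`, `simulateChecks` and `kReportedEntries` are hypothetical stand-ins, not the real classes in `DialectConversion.cpp`, and `dynamic_cast` replaces LLVM-style RTTI. The key observation is that a correct pattern never converts a block twice. The lookup therefore never finds a match, and every call scans the entire stack.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical, simplified model of the rewrite stack (not the real MLIR
// classes). Only some entries record a block type conversion.
struct Block {};
struct Rewrite {
  virtual ~Rewrite() = default;
};
struct BlockTypeConversionRewrite : Rewrite {
  explicit BlockTypeConversionRewrite(Block *b) : block(b) {}
  Block *getBlock() const { return block; }
  Block *block;
};
struct OtherRewrite : Rewrite {};

using RewriteStack = std::vector<std::unique_ptr<Rewrite>>;

// Linear scan in the spirit of the block overload of `hasRewrite`.
// `inspected` counts every stack entry looked at, to make the cost visible.
bool hasBlockConversion(const RewriteStack &rewrites, Block *block,
                        std::uint64_t &inspected) {
  for (const auto &rewrite : rewrites) {
    ++inspected;
    auto *conv = dynamic_cast<BlockTypeConversionRewrite *>(rewrite.get());
    if (conv && conv->getBlock() == block)
      return true;
  }
  return false;
}

// Runs `numChecks` checks for fresh (never converted) blocks against a stack
// of `stackSize` unrelated rewrites; returns the number of entries scanned.
std::uint64_t simulateChecks(std::uint64_t numChecks, std::uint64_t stackSize) {
  RewriteStack rewrites;
  for (std::uint64_t i = 0; i < stackSize; ++i)
    rewrites.push_back(std::make_unique<OtherRewrite>());
  std::vector<Block> blocks(numChecks);
  std::uint64_t inspected = 0;
  for (Block &block : blocks) {
    // The lookup misses for a fresh block, so the whole stack is scanned.
    bool found = hasBlockConversion(rewrites, &block, inspected);
    assert(!found && "fresh block must not be found");
    (void)found;
  }
  return inspected;
}

// Back-of-the-envelope entry count for the workload quoted above.
constexpr std::uint64_t kReportedEntries = 45000ull * 500000ull;
```

For example, `simulateChecks(50, 1000)` inspects 50 * 1000 = 50000 entries. `kReportedEntries` evaluates to 22500000000, which corresponds to roughly 2.25 * 10^10 entry inspections for the reported workload.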

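This PR moves the check behind `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`,
so it only runs when expensive pattern API checks are enabled. For
consistency, the other `hasRewrite`-based check (which verifies that a
pattern replaced or updated its root operation) is moved there as well.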
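
The diff also swaps `#ifndef NDEBUG` plus `assert` for `#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` plus `llvm::report_fatal_error`. A likely reason is that `assert` compiles to nothing under `NDEBUG`, so it would stay silent in a release build even with the expensive checks turned on. The standalone sketch below illustrates that gating pattern. `fatalError` (a stand-in for `llvm::report_fatal_error`), `convertBlock`, `numScans` and the vector-of-ids "rewrite stack" are hypothetical. The macro defaults to 0 here only so the snippet compiles on its own. In MLIR the value comes from the build configuration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Default for a standalone build; pass
// -DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=1 to enable the check.
#ifndef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
#define MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0
#endif

// Hypothetical stand-in for llvm::report_fatal_error: unlike assert(), it
// reports and aborts even when NDEBUG is defined.
[[noreturn]] void fatalError(const char *msg) {
  std::fprintf(stderr, "fatal error: %s\n", msg);
  std::abort();
}

// Number of linear scans performed; stays 0 when the check is compiled out.
int numScans = 0;

// Hypothetical model: the "rewrite stack" is just the ids of converted blocks.
void convertBlock(std::vector<int> &convertedBlocks, int blockId) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
  // Expensive O(stack size) lookup, only present in checked builds.
  ++numScans;
  if (std::find(convertedBlocks.begin(), convertedBlocks.end(), blockId) !=
      convertedBlocks.end())
    fatalError("block was already converted");
#endif
  convertedBlocks.push_back(blockId);
}
```

With the default of 0, the preprocessor removes the scan and `numScans` stays 0. Skipping that scan is the saving the PR is after. With the flag set to 1, converting the same block id twice aborts.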

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index cedf645e2985da..1607740a1ee076 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -714,6 +714,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 };
 } // namespace
 
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 /// Return "true" if there is an operation rewrite that matches the specified
 /// rewrite type and operation among the given rewrites.
 template <typename RewriteTy, typename R>
@@ -724,7 +725,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
   });
 }
 
-#ifndef NDEBUG
 /// Return "true" if there is a block rewrite that matches the specified
 /// rewrite type and block among the given rewrites.
 template <typename RewriteTy, typename R>
@@ -734,7 +734,7 @@ static bool hasRewrite(R &&rewrites, Block *block) {
     return rewriteTy && rewriteTy->getBlock() == block;
   });
 }
-#endif // NDEBUG
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
@@ -1292,9 +1292,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     ConversionPatternRewriter &rewriter, Block *block,
     const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // A block cannot be converted multiple times.
-  assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
-         "block was already converted");
+  if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block))
+    llvm::report_fatal_error("block was already converted");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
   OpBuilder::InsertionGuard g(rewriter);
 
   // If no arguments are being changed or added, there is nothing to do.
@@ -2236,9 +2239,9 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
                                           ConversionPatternRewriter &rewriter,
                                           RewriterState &curState) {
   auto &impl = rewriter.getImpl();
-
-#ifndef NDEBUG
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // Check that the root was either replaced or updated in place.
   auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
   auto replacedRoot = [&] {
@@ -2247,9 +2250,9 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   auto updatedRootInPlace = [&] {
     return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
   };
-  assert((replacedRoot() || updatedRootInPlace()) &&
-         "expected pattern to replace the root operation");
-#endif // NDEBUG
+  if (!replacedRoot() && !updatedRootInPlace())
+    llvm::report_fatal_error("expected pattern to replace the root operation");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
   RewriterState newState = impl.getCurrentState();
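
The second check, in `OperationLegalizer::legalizePatternResult`, only looks at rewrites recorded after `curState.numRewrites` (via `llvm::drop_begin`). It requires that one of them replaced or modified the root operation. Below is a rough C++ sketch of that windowed lookup using simplified hypothetical types: `Operation`, `Record`, `ReplaceOpRecord`, `ModifyOpRecord`, `hasRecord` and `rootWasReplacedOrUpdated`. The diff elides the body of `replacedRoot`, so the replacement record type here is an assumption rather than the real one.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical, simplified stand-ins for the rewrite classes (the real ones
// in DialectConversion.cpp use LLVM-style RTTI instead of dynamic_cast).
struct Operation {};
struct Record {
  explicit Record(Operation *op) : op(op) {}
  virtual ~Record() = default;
  Operation *op;
};
struct ReplaceOpRecord : Record { using Record::Record; };
struct ModifyOpRecord : Record { using Record::Record; };

using RecordStack = std::vector<std::unique_ptr<Record>>;

// "Is there a record of type RecordTy for `op` in [first, last)?", mirroring
// the shape of the templated hasRewrite helpers.
template <typename RecordTy, typename It>
bool hasRecord(It first, It last, Operation *op) {
  return std::any_of(first, last, [&](const std::unique_ptr<Record> &r) {
    auto *typed = dynamic_cast<RecordTy *>(r.get());
    return typed && typed->op == op;
  });
}

// Checks that a pattern replaced or modified `root`, looking only at the
// records pushed after `numRecordsBefore` (similar to drop_begin).
bool rootWasReplacedOrUpdated(const RecordStack &stack,
                              std::size_t numRecordsBefore, Operation *root) {
  auto first = stack.begin() + static_cast<std::ptrdiff_t>(numRecordsBefore);
  return hasRecord<ReplaceOpRecord>(first, stack.end(), root) ||
         hasRecord<ModifyOpRecord>(first, stack.end(), root);
}
```

The window starts at the state saved before the pattern ran, so this check scans only the new rewrites and is usually much cheaper than the block check. Per the log message, it is moved behind the flag for consistency rather than for cost.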