[Mlir-commits] [mlir] [MLIR][SCF] fix scf.index_switch fold convergence (#98535) (PR #98680)

Keyi Zhang llvmlistbot at llvm.org
Mon Jul 15 16:36:31 PDT 2024


https://github.com/Kuree updated https://github.com/llvm/llvm-project/pull/98680

>From 552773b4aa9de1b2a4acf3ba6cc1fa3fb4dd5608 Mon Sep 17 00:00:00 2001
From: Keyi Zhang <keyi at efficient.computer>
Date: Fri, 12 Jul 2024 11:53:09 -0700
Subject: [PATCH] [MLIR][SCF] fix scf.index_switch fold convergence (#98535)

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |  2 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 57 +++++++++++++---------
 mlir/test/Dialect/SCF/canonicalize.mlir    | 18 +++++++
 3 files changed, 52 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index f35ea962bea16..acbcbae105dbf 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1159,7 +1159,9 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     Block &getCaseBlock(unsigned idx);
   }];
 
-  let hasFolder = 1;
+  // Constant-selector folding is a canonicalization pattern, not a folder:
+  // the fold mechanism cannot replace a switch that has no results.
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 907d7f794593d..4de8dacc0edbf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4297,33 +4297,56 @@ void IndexSwitchOp::getRegionInvocationBounds(
     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
 }
 
-LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
-                                  SmallVectorImpl<OpFoldResult> &results) {
-  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
-  if (!maybeCst.has_value())
-    return failure();
-  int64_t cst = *maybeCst;
-  int64_t caseIdx, e = getNumCases();
-  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
-    if (cst == getCases()[caseIdx])
-      break;
-  }
+namespace {
+/// Folds an `scf.index_switch` whose selector is a constant by inlining the
+/// body of the selected region -- the first case whose value equals the
+/// constant, or the default region if none does -- in place of the op and
+/// replacing the op's results with the values that region yields.
+///
+/// This is a rewrite pattern rather than a folder because the fold mechanism
+/// cannot replace an op whose result list is empty (see #98535).
+struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
+  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
 
-  Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
-                                        : getDefaultRegion();
-  Block &source = r.front();
-  results.assign(source.getTerminator()->getOperands().begin(),
-                 source.getTerminator()->getOperands().end());
+  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // If `op.getArg()` is a constant, select the region that matches with
+    // the constant value. Use the default region if no match is found.
+    std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
+    if (!maybeCst.has_value())
+      return failure();
+    int64_t cst = *maybeCst;
+    // Linear scan over the case values; `caseIdx == e` afterwards means no
+    // case matched and the default region is taken.
+    int64_t caseIdx, e = op.getNumCases();
+    for (caseIdx = 0; caseIdx < e; ++caseIdx) {
+      if (cst == op.getCases()[caseIdx])
+        break;
+    }
 
-  Block *pDestination = (*this)->getBlock();
-  if (!pDestination)
-    return failure();
-  Block::iterator insertionPoint = (*this)->getIterator();
-  pDestination->getOperations().splice(insertionPoint, source.getOperations(),
-                                       source.getOperations().begin(),
-                                       std::prev(source.getOperations().end()));
+    Region &r = (caseIdx < e) ? op.getCaseRegions()[caseIdx]
+                              : op.getDefaultRegion();
+    Block &source = r.front();
+    Operation *terminator = source.getTerminator();
+    // Copy the yielded values now; the terminator is erased below.
+    SmallVector<Value> results = terminator->getOperands();
 
-  return success();
+    // Move the selected block's operations in front of the switch, then drop
+    // the terminator that was carried along with them.
+    rewriter.inlineBlockBefore(&source, op);
+    rewriter.eraseOp(terminator);
+    // Replace the operation with a potentially empty list of results.
+    // Fold mechanism doesn't support the case where the result list is empty.
+    rewriter.replaceOp(op, results);
+
+    return success();
+  }
+};
+} // namespace
+
+void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldConstantCase>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 459ccd73cfe61..268946803de7a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1846,3 +1846,24 @@ func.func @index_switch_fold() -> (f32, f32) {
 //  CHECK-NEXT:   %[[c1:.*]] = arith.constant 1.000000e+00 : f32
 //  CHECK-NEXT:   %[[c42:.*]] = arith.constant 4.200000e+01 : f32
 //  CHECK-NEXT:   return %[[c1]], %[[c42]] : f32, f32
+
+// -----
+
+// A constant-selector switch without results is folded away: the default
+// region is inlined, and the switch and its now-unused selector disappear.
+func.func @index_switch_fold_no_res() {
+  %c1 = arith.constant 1 : index
+  scf.index_switch %c1
+  case 0 {
+    scf.yield
+  }
+  default {
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @index_switch_fold_no_res()
+//  CHECK-NEXT: "test.op"() : () -> ()
+//  CHECK-NEXT: return



More information about the Mlir-commits mailing list