[Mlir-commits] [mlir] [mlir][SCF] Add folding for IndexSwitchOp (PR #70924)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 1 06:27:36 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/70924

>From 8915715b28e5f074cb4f05ad234fd89e2cce0c77 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 1 Nov 2023 11:39:29 +0000
Subject: [PATCH] [mlir][SCF] Add folding for IndexSwitchOp

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  1 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 29 ++++++++++++++++
 mlir/test/Dialect/SCF/canonicalize.mlir       | 33 +++++++++++++++++++
 .../SCF/for-loop-canonicalization.mlir        | 14 ++++----
 4 files changed, 70 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 38937fe28949436..2c5abe7a63ac44d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1126,6 +1126,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     Block &getCaseBlock(unsigned idx);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bc33fe2a9a01079..646284f8a9db435 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4166,6 +4166,51 @@ void IndexSwitchOp::getRegionInvocationBounds(
     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
 }
 
+namespace {
+/// Replaces an `scf.index_switch` whose selector is a constant with the body
+/// of the region that the constant selects: the matching case region, or the
+/// default region when no case value matches.
+///
+/// This is a canonicalization pattern rather than a `fold` hook because it
+/// moves operations out of the switch into the enclosing block and may have
+/// to erase a switch that has no results; a folder can do neither (it may only
+/// return existing values/attributes or update the root op in place).
+struct FoldConstantCase : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
+    if (!maybeCst)
+      return failure();
+
+    // Select the region to execute; fall back to the default region when the
+    // constant does not match any case value.
+    ArrayRef<int64_t> cases = op.getCases();
+    const auto *it = llvm::find(cases, *maybeCst);
+    Region &region =
+        it != cases.end()
+            ? op.getCaseRegions()[std::distance(cases.begin(), it)]
+            : op.getDefaultRegion();
+    Block &source = region.front();
+    Operation *terminator = source.getTerminator();
+    // Capture the yielded values before the terminator is erased.
+    SmallVector<Value> results = llvm::to_vector(terminator->getOperands());
+
+    rewriter.inlineBlockBefore(&source, op);
+    rewriter.eraseOp(terminator);
+    // `replaceOp` also handles switches with an empty result list.
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+} // namespace
+
+void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldConstantCase>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e55be6127fe2417..9dbf8d5dab11ae6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1756,3 +1756,39 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
 // CHECK:         parallel_insert_slice
 // CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
 // CHECK:         tensor.cast
+
+// -----
+
+// A switch on a constant index is replaced by the region it selects: case 1
+// for %0, and the default region for %1 since there is no case 2.
+
+func.func @index_switch_fold() -> (f32, f32) {
+  %switch_cst = arith.constant 1 : index
+  %0 = scf.index_switch %switch_cst -> f32
+  case 1 {
+    %y = arith.constant 1.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  %switch_cst_2 = arith.constant 2 : index
+  %1 = scf.index_switch %switch_cst_2 -> f32
+  case 0 {
+    %y = arith.constant 0.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @index_switch_fold()
+//  CHECK-NEXT:   %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+//  CHECK-NEXT:   %[[c42:.*]] = arith.constant 4.200000e+01 : f32
+//  CHECK-NEXT:   return %[[c1]], %[[c42]] : f32, f32
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 83c236ea6b4a7d5..6fb475efcb6586f 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -394,14 +394,18 @@ func.func @regression_multiplication_with_sym(%A : memref<i64>) {
 
 // -----
 
 // Make sure min is transformed into zero.
+// The result feeds a call to @foo so the folded constant can be checked.
 
-// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
-// CHECK: scf.index_switch %[[ZERO]] -> i1
+// CHECK-LABEL: func.func @func1()
+//       CHECK:   %[[ZERO:.+]] = arith.constant 0 : index
+//       CHECK:   call @foo(%[[ZERO]]) : (index) -> ()
 
 #map6 = affine_map<(d0, d1, d2) -> (d0 floordiv 64)>
 #map29 = affine_map<(d0, d1, d2) -> (d2 * 64 - 2, 5, (d1 mod 4) floordiv 8)>
 module {
+  func.func private @foo(%0 : index) -> ()
+
   func.func @func1() {
     %true = arith.constant true
     %c0 = arith.constant 0 : index
@@ -412,11 +416,7 @@ module {
     %alloc_249 = memref.alloc() : memref<7xf32>
     %135 = affine.apply #map6(%c15, %c0, %c14)
     %163 = affine.min #map29(%c5, %135, %c11)
-    %196 = scf.index_switch %163 -> i1
-    default {
-      memref.assume_alignment %alloc_249, 1 : memref<7xf32>
-      scf.yield %true : i1
-    }
+    func.call @foo(%163) : (index) -> ()
     return
   }
 }



More information about the Mlir-commits mailing list