[Mlir-commits] [mlir] [mlir][SCF] Add folding for IndexSwitchOp (PR #70924)

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 1 06:18:02 PDT 2023


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/70924

>From 4124257d90350503e92318ea3e5d06a86bb212a2 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 1 Nov 2023 11:39:29 +0000
Subject: [PATCH] [mlir][SCF] Add folding for IndexSwitchOp

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |  1 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 29 +++++++++++++++++++
 mlir/test/Dialect/SCF/canonicalize.mlir    | 33 ++++++++++++++++++++++
 3 files changed, 63 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 38937fe28949436..2c5abe7a63ac44d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1126,6 +1126,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     Block &getCaseBlock(unsigned idx);
   }];
 
+  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bc33fe2a9a01079..646284f8a9db435 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4166,6 +4166,35 @@ void IndexSwitchOp::getRegionInvocationBounds(
     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
 }
 
+/// Folds a switch on a constant index by moving the selected region's ops
+/// before this op; `results` receives the values yielded by that region.
+LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
+  // Bail out before touching `results` or the IR if nothing can be folded.
+  Block *pDestination = (*this)->getBlock();
+  if (!maybeCst.has_value() || !pDestination)
+    return failure();
+  // Select the region of the first matching case, else the default region.
+  int64_t caseIdx, e = getNumCases();
+  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
+    if (*maybeCst == getCases()[caseIdx])
+      break;
+  }
+  Region &r = (caseIdx < e) ? getCaseRegions()[caseIdx] : getDefaultRegion();
+  Block &source = r.front();
+  Operation *terminator = source.getTerminator();
+  results.assign(terminator->getOperands().begin(),
+                 terminator->getOperands().end());
+  // Move all ops but the terminator in front of this op. NOTE(review): folders
+  // are expected not to mutate other IR; consider a canonicalization pattern.
+  pDestination->getOperations().splice((*this)->getIterator(),
+                                       source.getOperations(),
+                                       source.getOperations().begin(),
+                                       terminator->getIterator());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e55be6127fe2417..9dbf8d5dab11ae6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1756,3 +1756,36 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
 // CHECK:         parallel_insert_slice
 // CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
 // CHECK:         tensor.cast
+
+// -----
+
+func.func @index_switch_fold() -> (f32, f32) {
+  %switch_cst = arith.constant 1: index
+  %0 = scf.index_switch %switch_cst -> f32
+  case 1 {
+    %y = arith.constant 1.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+  // Constant 2 matches no case below, so this switch exercises the default.
+  %switch_cst_2 = arith.constant 2: index
+  %1 = scf.index_switch %switch_cst_2 -> f32
+  case 0 {
+    %y = arith.constant 0.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+  
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @index_switch_fold()
+//  CHECK-NEXT:   %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+//  CHECK-NEXT:   %[[c42:.*]] = arith.constant 4.200000e+01 : f32
+//  CHECK-NEXT:   return %[[c1]], %[[c42]] : f32, f32



More information about the Mlir-commits mailing list