[Mlir-commits] [mlir] [mlir][SCF] Add folding for IndexSwitchOp (PR #70924)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 1 06:18:02 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/70924
>From 4124257d90350503e92318ea3e5d06a86bb212a2 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 1 Nov 2023 11:39:29 +0000
Subject: [PATCH] [mlir][SCF] Add folding for IndexSwitchOp
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 1 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 29 +++++++++++++++++++
mlir/test/Dialect/SCF/canonicalize.mlir | 33 ++++++++++++++++++++++
3 files changed, 63 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 38937fe28949436..2c5abe7a63ac44d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1126,6 +1126,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
Block &getCaseBlock(unsigned idx);
}];
+ let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bc33fe2a9a01079..646284f8a9db435 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4166,6 +4166,35 @@ void IndexSwitchOp::getRegionInvocationBounds(
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
}
+/// Folds an `scf.index_switch` whose selector is a known constant, provided
+/// the selected region's block contains nothing but its terminator. In that
+/// case the switch results are replaced by the values the terminator yields.
+///
+/// A `fold` hook may only return pre-existing values/attributes or update the
+/// op in place; it must not create, move, or erase other operations. Folders
+/// also run from `OpBuilder::createOrFold` and from dialect conversion, where
+/// such mutations bypass the rewriter and cannot be rolled back. Inlining a
+/// case body that still contains other operations therefore has to be done by
+/// a canonicalization pattern, not here.
+/// NOTE(review): under the canonicalizer this relies on constants defined in
+/// the case body having been moved out of the region before the fold runs —
+/// confirm against the canonicalize.mlir test.
+LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+  // Folding works by replacing results. A result-less switch has nothing to
+  // replace, and returning success with an empty list would claim an in-place
+  // update that never happened.
+  if (getNumResults() == 0)
+    return failure();
+
+  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
+  if (!maybeCst.has_value())
+    return failure();
+  int64_t cst = *maybeCst;
+
+  // Select the first case whose value matches; fall back to the default
+  // region when none does.
+  Region *selected = &getDefaultRegion();
+  for (unsigned idx = 0, e = getNumCases(); idx < e; ++idx) {
+    if (getCases()[idx] == cst) {
+      selected = &getCaseRegions()[idx];
+      break;
+    }
+  }
+
+  Block &source = selected->front();
+  Operation *terminator = source.getTerminator();
+  // Any operation besides the terminator would have to be moved out of the
+  // region, which a folder is not allowed to do.
+  if (&source.front() != terminator)
+    return failure();
+
+  // Only the terminator remains, so no yielded value is produced inside this
+  // block (index_switch case blocks take no arguments, per the op definition);
+  // every yielded value is defined above the switch and can replace its
+  // results directly.
+  results.append(terminator->getOperands().begin(),
+                 terminator->getOperands().end());
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e55be6127fe2417..9dbf8d5dab11ae6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1756,3 +1756,36 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// CHECK: parallel_insert_slice
// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
// CHECK: tensor.cast
+
+// -----
+
+// Both switches have constant selectors: the first takes `case 1`, the second
+// matches no case and takes `default`. Only the selected yields survive.
+func.func @index_switch_fold() -> (f32, f32) {
+  %switch_cst = arith.constant 1: index
+  %0 = scf.index_switch %switch_cst -> f32
+  case 1 {
+    %y = arith.constant 1.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  %switch_cst_2 = arith.constant 2: index
+  %1 = scf.index_switch %switch_cst_2 -> f32
+  case 0 {
+    %y = arith.constant 0.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  return %0, %1 : f32, f32
+}
+
+// The order in which the surviving constants are printed is not what this
+// test is about, so they are matched with CHECK-DAG.
+// CHECK-LABEL: func.func @index_switch_fold()
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+//   CHECK-DAG:   %[[c42:.*]] = arith.constant 4.200000e+01 : f32
+//   CHECK-NOT:   scf.index_switch
+//       CHECK:   return %[[c1]], %[[c42]] : f32, f32
+
+// -----
+
+// A switch on a non-constant selector must be left untouched.
+func.func @index_switch_no_fold(%arg: index) -> f32 {
+  %0 = scf.index_switch %arg -> f32
+  case 1 {
+    %y = arith.constant 1.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @index_switch_no_fold(
+//       CHECK:   scf.index_switch
More information about the Mlir-commits
mailing list