[Mlir-commits] [mlir] [mlir][SCF] Add folding for IndexSwitchOp (PR #70924)
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 1 06:27:36 PDT 2023
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/70924
From 8915715b28e5f074cb4f05ad234fd89e2cce0c77 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Wed, 1 Nov 2023 11:39:29 +0000
Subject: [PATCH] [mlir][SCF] Add folding for IndexSwitchOp
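
When the switch operand is a known constant, the fold inlines the body of the
matching case region (or of the default region when no case matches) right
before the op and replaces the op's results with that region's yielded values.
A minimal before/after sketch, based on the test added below:

  %c1 = arith.constant 1 : index
  %0 = scf.index_switch %c1 -> f32
  case 1 {
    %y = arith.constant 1.0 : f32
    scf.yield %y : f32
  }
  default {
    %y = arith.constant 42.0 : f32
    scf.yield %y : f32
  }

folds, after canonicalization, to:

  %0 = arith.constant 1.000000e+00 : f32
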
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 1 +
mlir/lib/Dialect/SCF/IR/SCF.cpp | 29 ++++++++++++++++
mlir/test/Dialect/SCF/canonicalize.mlir | 33 +++++++++++++++++++
.../SCF/for-loop-canonicalization.mlir | 14 ++++----
4 files changed, 70 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 38937fe28949436..2c5abe7a63ac44d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -1126,6 +1126,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     Block &getCaseBlock(unsigned idx);
   }];
 
+  let hasFolder = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index bc33fe2a9a01079..646284f8a9db435 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4166,6 +4166,35 @@ void IndexSwitchOp::getRegionInvocationBounds(
     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
 }
 
+LogicalResult IndexSwitchOp::fold(FoldAdaptor adaptor,
+                                  SmallVectorImpl<OpFoldResult> &results) {
+  std::optional<int64_t> maybeCst = getConstantIntValue(getArg());
+  if (!maybeCst.has_value())
+    return failure();
+  int64_t cst = *maybeCst;
+  int64_t caseIdx, e = getNumCases();
+  for (caseIdx = 0; caseIdx < e; ++caseIdx) {
+    if (cst == getCases()[caseIdx])
+      break;
+  }
+
+  Region &r = (caseIdx < getNumCases()) ? getCaseRegions()[caseIdx]
+                                        : getDefaultRegion();
+  Block &source = r.front();
+  results.assign(source.getTerminator()->getOperands().begin(),
+                 source.getTerminator()->getOperands().end());
+
+  Block *pDestination = (*this)->getBlock();
+  if (!pDestination)
+    return failure();
+  Block::iterator insertionPoint = (*this)->getIterator();
+  pDestination->getOperations().splice(insertionPoint, source.getOperations(),
+                                       source.getOperations().begin(),
+                                       std::prev(source.getOperations().end()));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e55be6127fe2417..9dbf8d5dab11ae6 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1756,3 +1756,36 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
// CHECK: parallel_insert_slice
// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
// CHECK: tensor.cast
+
+// -----
+
+func.func @index_switch_fold() -> (f32, f32) {
+  %switch_cst = arith.constant 1: index
+  %0 = scf.index_switch %switch_cst -> f32
+  case 1 {
+    %y = arith.constant 1.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  %switch_cst_2 = arith.constant 2: index
+  %1 = scf.index_switch %switch_cst_2 -> f32
+  case 0 {
+    %y = arith.constant 0.0 : f32
+    scf.yield %y : f32
+  }
+  default {
+    %y = arith.constant 42.0 : f32
+    scf.yield %y : f32
+  }
+
+  return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @index_switch_fold()
+// CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[c42:.*]] = arith.constant 4.200000e+01 : f32
+// CHECK-NEXT: return %[[c1]], %[[c42]] : f32, f32
diff --git a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
index 83c236ea6b4a7d5..6fb475efcb6586f 100644
--- a/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -394,14 +394,18 @@ func.func @regression_multiplication_with_sym(%A : memref<i64>) {
// -----
+
// Make sure min is transformed into zero.
-// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
-// CHECK: scf.index_switch %[[ZERO]] -> i1
+// CHECK-LABEL: func.func @func1()
+// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
+// CHECK: call @foo(%[[ZERO]]) : (index) -> ()
#map6 = affine_map<(d0, d1, d2) -> (d0 floordiv 64)>
#map29 = affine_map<(d0, d1, d2) -> (d2 * 64 - 2, 5, (d1 mod 4) floordiv 8)>
module {
+  func.func private @foo(%0 : index) -> ()
+
   func.func @func1() {
     %true = arith.constant true
     %c0 = arith.constant 0 : index
@@ -412,11 +416,7 @@ module {
     %alloc_249 = memref.alloc() : memref<7xf32>
     %135 = affine.apply #map6(%c15, %c0, %c14)
     %163 = affine.min #map29(%c5, %135, %c11)
-    %196 = scf.index_switch %163 -> i1
-    default {
-      memref.assume_alignment %alloc_249, 1 : memref<7xf32>
-      scf.yield %true : i1
-    }
+    func.call @foo(%163) : (index) -> ()
     return
   }
 }
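
A note on the test update above (inferred from the diff): once the affine.min
folds to the constant 0, the new IndexSwitchOp fold would inline the default
region and erase the switch, so the old CHECK on `scf.index_switch %[[ZERO]]`
could no longer match. Roughly:

  %zero = arith.constant 0 : index
  %196 = scf.index_switch %zero -> i1
  default {
    memref.assume_alignment %alloc_249, 1 : memref<7xf32>
    scf.yield %true : i1
  }

now folds to just the default body:

  memref.assume_alignment %alloc_249, 1 : memref<7xf32>

The test therefore routes the folded index through a call to @foo so that the
"min is transformed into zero" check still has an anchor.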