[Mlir-commits] [mlir] [MLIR] fix invalid `scf.index_switch` lowering to `cf.switch` when case values are large (PR #111590)
Keyi Zhang
llvmlistbot at llvm.org
Tue Oct 8 14:24:15 PDT 2024
https://github.com/Kuree created https://github.com/llvm/llvm-project/pull/111590
This PR fixes https://github.com/llvm/llvm-project/issues/111589 by using `int64_t` instead of `int32_t` when converting case values. With `int32_t`, case values that do not fit in 32 bits can overflow and produce invalid IR, as shown in the issue. A test case is also added.
>From 0b942c5b04aedaa3ba58663e9eebbe4c19fac86d Mon Sep 17 00:00:00 2001
From: Keyi Zhang <keyi at efficient.computer>
Date: Tue, 8 Oct 2024 14:17:07 -0700
Subject: [PATCH] [MLIR] fix invalid scf.index_switch lowering
---
.../SCFToControlFlow/SCFToControlFlow.cpp | 9 +++++---
.../SCFToControlFlow/convert-to-cfg.mlir | 21 +++++++++++++++++--
2 files changed, 25 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 45f3bcfa393be8..5b7b6713397048 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -669,7 +669,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
// Convert the case regions.
SmallVector<Block *> caseSuccessors;
- SmallVector<int32_t> caseValues;
+ SmallVector<int64_t> caseValues;
caseSuccessors.reserve(op.getCases().size());
caseValues.reserve(op.getCases().size());
for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
@@ -691,11 +691,14 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
// Cast switch index to integer case value.
Value caseValue = rewriter.create<arith::IndexCastOp>(
- op.getLoc(), rewriter.getI32Type(), op.getArg());
+ op.getLoc(), rewriter.getI64Type(), op.getArg());
+ ShapedType caseValueType = VectorType::get(
+ static_cast<int64_t>(caseValues.size()), rewriter.getI64Type());
rewriter.create<cf::SwitchOp>(
op.getLoc(), caseValue, *defaultBlock, ValueRange(),
- rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+ DenseIntElementsAttr::get(caseValueType, caseValues), caseSuccessors,
+ caseOperands);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff7868..ba841313320194 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -622,8 +622,8 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK-LABEL: @index_switch
func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
- // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
- // CHECK: cf.switch %[[CASE]] : i32
+ // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
+ // CHECK: cf.switch %[[CASE]] : i64
// CHECK-NEXT: default: ^[[DEFAULT:.+]],
// CHECK-NEXT: 0: ^[[bb1:.+]],
// CHECK-NEXT: 1: ^[[bb2:.+]]
@@ -648,6 +648,23 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
return %0 : i32
}
+// CHECK-LABEL: @index_switch_large_case
+func.func @index_switch_large_case(%i : index) {
+ // CHECK: cf.switch
+ // CHECK: 4294967296: ^[[bb1:.+]]
+ scf.index_switch %i
+ case 4294967296 { // 2^32 does not fit in 32 bits; regression value for the 64-bit case lowering
+ // CHECK: ^[[bb1]]:
+ // CHECK-NEXT: "test.op"
+ "test.op"() : () -> ()
+ scf.yield
+ }
+ default {
+ scf.yield
+ }
+ return
+}
+
// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
// scf.for and then to unstructured control flow. scf.parallel could lower more
// efficiently to multi-threaded IR, at which point scf.forall would
More information about the Mlir-commits mailing list