[Mlir-commits] [mlir] [MLIR] fix invalid `scf.index_switch` lowering to `cf.switch` when case values are large (PR #111590)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 8 14:24:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Keyi Zhang (Kuree)

<details>
<summary>Changes</summary>

This PR fixes https://github.com/llvm/llvm-project/issues/111589 by making sure `int64_t` is used when converting case values in the `scf.index_switch` to `cf.switch` lowering. Using `int32_t` may overflow for large case values and result in invalid IR, as shown in the issue. A test case with a case value of 2^32 is also added.

---
Full diff: https://github.com/llvm/llvm-project/pull/111590.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+6-3) 
- (modified) mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir (+19-2) 


``````````diff
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 45f3bcfa393be8..5b7b6713397048 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -669,7 +669,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Convert the case regions.
   SmallVector<Block *> caseSuccessors;
-  SmallVector<int32_t> caseValues;
+  SmallVector<int64_t> caseValues;
   caseSuccessors.reserve(op.getCases().size());
   caseValues.reserve(op.getCases().size());
   for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
@@ -691,11 +691,14 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
 
   // Cast switch index to integer case value.
   Value caseValue = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getI32Type(), op.getArg());
+      op.getLoc(), rewriter.getI64Type(), op.getArg());
 
+  ShapedType caseValueType = VectorType::get(
+      static_cast<int64_t>(caseValues.size()), rewriter.getI64Type());
   rewriter.create<cf::SwitchOp>(
       op.getLoc(), caseValue, *defaultBlock, ValueRange(),
-      rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+      DenseIntElementsAttr::get(caseValueType, caseValues), caseSuccessors,
+      caseOperands);
   rewriter.replaceOp(op, continueBlock->getArguments());
   return success();
 }
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index 9ea0093eff7868..ba841313320194 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -622,8 +622,8 @@ func.func @func_execute_region_elim_multi_yield() {
 
 // CHECK-LABEL: @index_switch
 func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
-  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
-  // CHECK: cf.switch %[[CASE]] : i32
+  // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
+  // CHECK: cf.switch %[[CASE]] : i64
   // CHECK-NEXT: default: ^[[DEFAULT:.+]],
   // CHECK-NEXT: 0: ^[[bb1:.+]],
   // CHECK-NEXT: 1: ^[[bb2:.+]]
@@ -648,6 +648,23 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
   return %0 : i32
 }
 
+// CHECK-LABEL: @index_switch_large_case
+func.func @index_switch_large_case(%i : index) {
+  // CHECK: cf.switch
+  // CHECK: 4294967296: ^[[bb1:.+]]
+  scf.index_switch %i
+  case 4294967296 { // 2^32
+    // CHECK: ^[[bb1]]:
+    // CHECK-NEXT: "test.op"
+    "test.op"() : () -> ()
+    scf.yield
+  }
+  default {
+    scf.yield
+  }
+  return
+}
+
 // Note: scf.forall is lowered to scf.parallel, which is currently lowered to
 // scf.for and then to unstructured control flow. scf.parallel could lower more
 // efficiently to multi-threaded IR, at which point scf.forall would

``````````

</details>


https://github.com/llvm/llvm-project/pull/111590


More information about the Mlir-commits mailing list