[Mlir-commits] [mlir] 30d87d4 - [MLIR][LLVM] Permit integer types in switch other than i32

William S. Moses llvmlistbot at llvm.org
Tue Nov 16 09:00:43 PST 2021


Author: William S. Moses
Date: 2021-11-16T12:00:37-05:00
New Revision: 30d87d4a5d02f00ef58ebc24a0ee5c6c370b8b4c

URL: https://github.com/llvm/llvm-project/commit/30d87d4a5d02f00ef58ebc24a0ee5c6c370b8b4c
DIFF: https://github.com/llvm/llvm-project/commit/30d87d4a5d02f00ef58ebc24a0ee5c6c370b8b4c.diff

LOG: [MLIR][LLVM] Permit integer types in switch other than i32

The LLVM dialect's switch op currently only permits an i32 condition value. Both LLVM IR and the MLIR Standard dialect's switch permit other integer types, which leads to an illegal state when lowering, for example, an i8 switch from the Standard dialect.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113955

Added: 
    

Modified: 
    flang/test/Fir/convert-to-llvm.fir
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 7cfd73c01bf4b..f597da6e4caf4 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -215,7 +215,7 @@ func @select(%arg : index, %arg2 : i32) -> i32 {
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK:         %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
-// CHECK:         llvm.switch %[[SELECTOR]], ^bb5 [
+// CHECK:         llvm.switch %[[SELECTOR]] : i32, ^bb5 [
 // CHECK:           1: ^bb1(%[[C0]] : i32),
 // CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
 // CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
@@ -260,7 +260,7 @@ func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
 // CHECK:         %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
-// CHECK:         llvm.switch %[[SELECTVALUE]], ^bb5 [
+// CHECK:         llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [
 // CHECK:           1: ^bb1(%[[C0]] : i32),
 // CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
 // CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6bd64edf44c4f..055975ef58bcf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -724,7 +724,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
      NoSideEffect]> {
   let arguments = (ins
-    I32:$value,
+    AnyInteger:$value,
     Variadic<AnyType>:$defaultOperands,
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
     OptionalAttr<ElementsAttr>:$case_values,
@@ -738,9 +738,9 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
 
   let verifier = [{ return ::verify(*this); }];
   let assemblyFormat = [{
-    $value `,`
+    $value `:` type($value) `,`
     $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
-    `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+    `[` `\n` custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations,
                                    $caseOperands, type($caseOperands)) `]`
     attr-dict
   }];

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index dd2ffdabd8c5f..0afc64b2ffce0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -269,20 +269,21 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
 static ParseResult parseSwitchOpCases(
-    OpAsmParser &parser, ElementsAttr &caseValues,
+    OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
-  SmallVector<int32_t> values;
-  int32_t value = 0;
+  SmallVector<APInt> values;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
   do {
+    int64_t value = 0;
     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
     if (values.empty() && !integerParseResult.hasValue())
       return success();
 
     if (!integerParseResult.hasValue() || integerParseResult.getValue())
       return failure();
-    values.push_back(value);
+    values.push_back(APInt(bitWidth, value));
 
     Block *destination;
     SmallVector<OpAsmParser::OperandType> operands;
@@ -299,11 +300,13 @@ static ParseResult parseSwitchOpCases(
     caseOperandTypes.emplace_back(operandTypes);
   } while (!parser.parseOptionalComma());
 
-  caseValues = parser.getBuilder().getI32VectorAttr(values);
+  ShapedType caseValueType =
+      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
   return success();
 }
 
-static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
                                ElementsAttr caseValues,
                                SuccessorRange caseDestinations,
                                OperandRangeRange caseOperands,

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
index 7f5500c875e7a..7e47448e7f3ba 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -64,7 +64,7 @@ func @coro_suspend() {
   // CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1
   // CHECK: %[[RET:.*]] = llvm.intr.coro.suspend %[[STATE]], %[[FINAL]]
   // CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32
-  // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+  // CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
   // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
   // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
   async.coro.suspend %2, ^suspend, ^resume, ^cleanup

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index eb8ddbb13e5d3..46ff7501f4b89 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -49,7 +49,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Decide the next block based on the code returned from suspend.
 // CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
-// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
 // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
 // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 9e9636a20ab83..7d0942ca8691b 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -592,3 +592,31 @@ func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : v
   %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
   std.return
 }
+
+// -----
+
+// CHECK-LABEL: func @switchi8(
+func @switchi8(%arg0 : i8) -> i32 {
+switch %arg0 : i8, [
+  default: ^bb1,
+    42: ^bb1,
+    43: ^bb3
+  ]
+^bb1:
+  %c_1 = arith.constant 1 : i32
+  std.return %c_1 : i32
+^bb3:
+  %c_42 = arith.constant 42 : i32
+  std.return %c_42: i32
+}
+// CHECK:     llvm.switch %arg0 : i8, ^bb1 [
+// CHECK-NEXT:       42: ^bb1,
+// CHECK-NEXT:       43: ^bb2
+// CHECK-NEXT:     ]
+// CHECK:   ^bb1:  // 2 preds: ^bb0, ^bb0
+// CHECK-NEXT:     %[[E0:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E0]] : i32
+// CHECK:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E1]] : i32
+// CHECK-NEXT:   }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 3f07f173ec875..fd9b5765fa2f2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -805,7 +805,7 @@ module attributes {llvm.data_layout = "#vjkr32"} {
 
 func @switch_wrong_number_of_weights(%arg0 : i32) {
   // expected-error at +1 {{expects number of branch weights to match number of successors: 3 vs 2}}
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     42: ^bb2(%arg0, %arg0 : i32, i32)
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
 

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 8efd14ee4f597..b931c9bb69e86 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -84,12 +84,12 @@ func @ops(%arg0: i32, %arg1: f32,
 // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : i47
   %22 = llvm.mlir.undef : !llvm.struct<(i32, f64, i32)>
   %23 = llvm.mlir.constant(42) : i47
-  // CHECK:      llvm.switch %0, ^[[BB3]] [
+  // CHECK:      llvm.switch %0 : i32, ^[[BB3]] [
   // CHECK-NEXT:   1: ^[[BB4:.*]],
   // CHECK-NEXT:   2: ^[[BB5:.*]],
   // CHECK-NEXT:   3: ^[[BB6:.*]]
   // CHECK-NEXT: ]
-  llvm.switch %0, ^bb3 [
+  llvm.switch %0 : i32, ^bb3 [
     1: ^bb4,
     2: ^bb5,
     3: ^bb6
@@ -97,24 +97,24 @@ func @ops(%arg0: i32, %arg1: f32,
 
 // CHECK: ^[[BB3]]
 ^bb3:
-// CHECK:      llvm.switch %0, ^[[BB7:.*]] [
+// CHECK:      llvm.switch %0 : i32, ^[[BB7:.*]] [
 // CHECK-NEXT: ]
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB4]]
 ^bb4:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB5]]
 ^bb5:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB6]]
 ^bb6:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB7]]

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7677e59c3cd91..f5b6d60662ad2 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1560,7 +1560,7 @@ llvm.func @switch_args(%arg0: i32) -> i32 {
   // CHECK-NEXT:   i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
   // CHECK-NEXT:   i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
   // CHECK-NEXT: ]
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     -1: ^bb2(%0 : i32),
     1: ^bb3(%1, %2 : i32, i32)
   ]
@@ -1590,7 +1590,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
   %1 = llvm.mlir.constant(23 : i32) : i32
   %2 = llvm.mlir.constant(29 : i32) : i32
   // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
-  llvm.switch %arg0, ^bb1(%0 : i32) [
+  llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
     9: ^bb2(%1, %2 : i32, i32),
     99: ^bb3
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}


        


More information about the Mlir-commits mailing list