[Mlir-commits] [mlir] 9499267 - [MLIR][LLVM] Permit integer types in switch other than i32

William S. Moses llvmlistbot at llvm.org
Mon Nov 15 21:46:30 PST 2021


Author: William S. Moses
Date: 2021-11-16T00:46:25-05:00
New Revision: 94992670fcc59d12d7f97cb08beb8d2eb15110ed

URL: https://github.com/llvm/llvm-project/commit/94992670fcc59d12d7f97cb08beb8d2eb15110ed
DIFF: https://github.com/llvm/llvm-project/commit/94992670fcc59d12d7f97cb08beb8d2eb15110ed.diff

LOG: [MLIR][LLVM] Permit integer types in switch other than i32

The LLVM dialect SwitchOp currently only permits i32. Both LLVM IR and the MLIR Standard switch permit other integer types, leading to an illegal state when lowering an i8 switch from MLIR Standard.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113955

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
    mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6bd64edf44c4f..055975ef58bcf 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -724,7 +724,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
      NoSideEffect]> {
   let arguments = (ins
-    I32:$value,
+    AnyInteger:$value,
     Variadic<AnyType>:$defaultOperands,
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
     OptionalAttr<ElementsAttr>:$case_values,
@@ -738,9 +738,9 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
 
   let verifier = [{ return ::verify(*this); }];
   let assemblyFormat = [{
-    $value `,`
+    $value `:` type($value) `,`
     $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
-    `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+    `[` `\n` custom<SwitchOpCases>(ref(type($value)), $case_values, $caseDestinations,
                                    $caseOperands, type($caseOperands)) `]`
     attr-dict
   }];

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index ba7ec4f283af2..2587dbd10701f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -269,20 +269,21 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
 static ParseResult parseSwitchOpCases(
-    OpAsmParser &parser, ElementsAttr &caseValues,
+    OpAsmParser &parser, Type &flagType, ElementsAttr &caseValues,
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
-  SmallVector<int32_t> values;
-  int32_t value = 0;
+  SmallVector<APInt> values;
+  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
   do {
+    int64_t value = 0;
     OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
     if (values.empty() && !integerParseResult.hasValue())
       return success();
 
     if (!integerParseResult.hasValue() || integerParseResult.getValue())
       return failure();
-    values.push_back(value);
+    values.push_back(APInt(bitWidth, value));
 
     Block *destination;
     SmallVector<OpAsmParser::OperandType> operands;
@@ -299,11 +300,14 @@ static ParseResult parseSwitchOpCases(
     caseOperandTypes.emplace_back(operandTypes);
   } while (!parser.parseOptionalComma());
 
-  caseValues = parser.getBuilder().getI32VectorAttr(values);
+  ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(values.size()), flagType); 
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
   return success();
 }
 
 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+                               Type &flagType,
                                ElementsAttr caseValues,
                                SuccessorRange caseDestinations,
                                OperandRangeRange caseOperands,

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
index 7f5500c875e7a..7e47448e7f3ba 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir
@@ -64,7 +64,7 @@ func @coro_suspend() {
   // CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1
   // CHECK: %[[RET:.*]] = llvm.intr.coro.suspend %[[STATE]], %[[FINAL]]
   // CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32
-  // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+  // CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
   // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
   // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
   async.coro.suspend %2, ^suspend, ^resume, ^cleanup

diff  --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
index eb8ddbb13e5d3..46ff7501f4b89 100644
--- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir
@@ -49,7 +49,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
 
 // Decide the next block based on the code returned from suspend.
 // CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32
-// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]]
+// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]]
 // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]]
 // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]]
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index 9e9636a20ab83..7d0942ca8691b 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -592,3 +592,31 @@ func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : v
   %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32>
   std.return
 }
+
+// -----
+
+// CHECK-LABEL: func @switchi8(
+func @switchi8(%arg0 : i8) -> i32 {
+switch %arg0 : i8, [
+  default: ^bb1,
+    42: ^bb1,
+    43: ^bb3
+  ]
+^bb1:
+  %c_1 = arith.constant 1 : i32
+  std.return %c_1 : i32
+^bb3:
+  %c_42 = arith.constant 42 : i32
+  std.return %c_42: i32
+}
+// CHECK:     llvm.switch %arg0 : i8, ^bb1 [
+// CHECK-NEXT:       42: ^bb1,
+// CHECK-NEXT:       43: ^bb2
+// CHECK-NEXT:     ]
+// CHECK:   ^bb1:  // 2 preds: ^bb0, ^bb0
+// CHECK-NEXT:     %[[E0:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E0]] : i32
+// CHECK:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32
+// CHECK-NEXT:     llvm.return %[[E1]] : i32
+// CHECK-NEXT:   }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 3f07f173ec875..fd9b5765fa2f2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -805,7 +805,7 @@ module attributes {llvm.data_layout = "#vjkr32"} {
 
 func @switch_wrong_number_of_weights(%arg0 : i32) {
   // expected-error @+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     42: ^bb2(%arg0, %arg0 : i32, i32)
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
 

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 8efd14ee4f597..b931c9bb69e86 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -84,12 +84,12 @@ func @ops(%arg0: i32, %arg1: f32,
 // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : i47
   %22 = llvm.mlir.undef : !llvm.struct<(i32, f64, i32)>
   %23 = llvm.mlir.constant(42) : i47
-  // CHECK:      llvm.switch %0, ^[[BB3]] [
+  // CHECK:      llvm.switch %0 : i32, ^[[BB3]] [
   // CHECK-NEXT:   1: ^[[BB4:.*]],
   // CHECK-NEXT:   2: ^[[BB5:.*]],
   // CHECK-NEXT:   3: ^[[BB6:.*]]
   // CHECK-NEXT: ]
-  llvm.switch %0, ^bb3 [
+  llvm.switch %0 : i32, ^bb3 [
     1: ^bb4,
     2: ^bb5,
     3: ^bb6
@@ -97,24 +97,24 @@ func @ops(%arg0: i32, %arg1: f32,
 
 // CHECK: ^[[BB3]]
 ^bb3:
-// CHECK:      llvm.switch %0, ^[[BB7:.*]] [
+// CHECK:      llvm.switch %0 : i32, ^[[BB7:.*]] [
 // CHECK-NEXT: ]
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB4]]
 ^bb4:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB5]]
 ^bb5:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB6]]
 ^bb6:
-  llvm.switch %0, ^bb7 [
+  llvm.switch %0 : i32, ^bb7 [
   ]
 
 // CHECK: ^[[BB7]]

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7677e59c3cd91..f5b6d60662ad2 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1560,7 +1560,7 @@ llvm.func @switch_args(%arg0: i32) -> i32 {
   // CHECK-NEXT:   i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
   // CHECK-NEXT:   i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
   // CHECK-NEXT: ]
-  llvm.switch %arg0, ^bb1 [
+  llvm.switch %arg0 : i32, ^bb1 [
     -1: ^bb2(%0 : i32),
     1: ^bb3(%1, %2 : i32, i32)
   ]
@@ -1590,7 +1590,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
   %1 = llvm.mlir.constant(23 : i32) : i32
   %2 = llvm.mlir.constant(29 : i32) : i32
   // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
-  llvm.switch %arg0, ^bb1(%0 : i32) [
+  llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
     9: ^bb2(%1, %2 : i32, i32),
     99: ^bb3
   ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}


        


More information about the Mlir-commits mailing list