[Mlir-commits] [mlir] 88f07a3 - [mlir][llvm] Fix import of SwitchOp

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 21 23:41:27 PDT 2023


Author: Tobias Gysi
Date: 2023-06-22T06:27:32Z
New Revision: 88f07a311947f88de82ad2de9b2d6a26eba21343

URL: https://github.com/llvm/llvm-project/commit/88f07a311947f88de82ad2de9b2d6a26eba21343
DIFF: https://github.com/llvm/llvm-project/commit/88f07a311947f88de82ad2de9b2d6a26eba21343.diff

LOG: [mlir][llvm] Fix import of SwitchOp

This revision ensures that SwitchOps whose condition and case values
have a bitwidth other than 32 bits are imported properly. It adds an
APInt-based builder to the SwitchOp and extends the verifier to check
that the condition type matches the case value type.

Reviewed By: Dinistro

Differential Revision: https://reviews.llvm.org/D153438

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/Import/control-flow.ll

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4081091a579d..a106592c9b5a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -931,7 +931,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     AnyInteger:$value,
     Variadic<AnyType>:$defaultOperands,
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
-    OptionalAttr<ElementsAttr>:$case_values,
+    OptionalAttr<AnyIntElementsAttr>:$case_values,
     DenseI32ArrayAttr:$case_operand_segments,
     OptionalAttr<ElementsAttr>:$branch_weights
   );
@@ -950,6 +950,13 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
   let hasVerifier = 1;
 
   let builders = [
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
     OpBuilder<(ins "Value":$value,
       "Block *":$defaultDestination,
       "ValueRange":$defaultOperands,
@@ -957,6 +964,13 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
       CArg<"BlockRange", "{}">:$caseDestinations,
       CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
       CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
+    OpBuilder<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
     LLVM_TerminatorPassthroughOpBuilder
   ];
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a97632c222287..7bbdb67d8c4a7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -355,25 +355,54 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result,
 
 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      Block *defaultDestination, ValueRange defaultOperands,
-                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     DenseIntElementsAttr caseValues, // Null if no cases.
+                     BlockRange caseDestinations, // One per case value.
                      ArrayRef<ValueRange> caseOperands,
                      ArrayRef<int32_t> branchWeights) {
-  ElementsAttr caseValuesAttr;
-  if (!caseValues.empty())
-    caseValuesAttr = builder.getI32VectorAttr(caseValues);
-
   ElementsAttr weightsAttr;
   if (!branchWeights.empty())
     weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
 
-  build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
+  build(builder, result, value, defaultOperands, caseOperands, caseValues,
         weightsAttr, defaultDestination, caseDestinations);
 }
 
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // Pack the case values into a vector attribute typed like the condition.
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    auto numCases = static_cast<int64_t>(caseValues.size());
+    ShapedType attrType = VectorType::get(numCases, value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(attrType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands, branchWeights);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // Sign-extend/truncate the 32-bit case values to the condition's bitwidth
+  // and forward to the APInt builder, so non-i32 conditions are supported.
+  SmallVector<APInt> apCaseValues;
+  for (int32_t caseValue : caseValues) {
+    unsigned width = value.getType().getIntOrFloatBitWidth();
+    apCaseValues.push_back(APInt(32, caseValue, true).sextOrTrunc(width));
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        apCaseValues, caseDestinations, caseOperands, branchWeights);
+}
+
 /// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
 ///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
 static ParseResult parseSwitchOpCases(
-    OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
+    OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
@@ -412,7 +441,7 @@ static ParseResult parseSwitchOpCases(
 }
 
 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
-                               ElementsAttr caseValues,
+                               DenseIntElementsAttr caseValues,
                                SuccessorRange caseDestinations,
                                OperandRangeRange caseOperands,
                                const TypeRangeRange &caseOperandTypes) {
@@ -421,7 +450,7 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
 
   size_t index = 0;
   llvm::interleave(
-      llvm::zip(llvm::cast<DenseIntElementsAttr>(caseValues), caseDestinations),
+      llvm::zip(caseValues, caseDestinations),
       [&](auto i) {
         p << "  ";
         p << std::get<0>(i).getLimitedValue();
@@ -446,6 +475,9 @@ LogicalResult SwitchOp::verify() {
     return emitError("expects number of branch weights to match number of "
                      "successors: ")
            << getBranchWeights()->size() << " vs " << getNumSuccessors();
+  if (getCaseValues() &&
+      getValue().getType() != getCaseValues()->getElementType())
+    return emitError("expects case value type to match condition value type");
   return success();
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index fbb20bed6e8ff..e8ffee33ea3bc 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1348,7 +1348,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     unsigned numCases = swInst->getNumCases();
     SmallVector<SmallVector<Value>> caseOperands(numCases);
     SmallVector<ValueRange> caseOperandRefs(numCases);
-    SmallVector<int32_t> caseValues(numCases);
+    SmallVector<APInt> caseValues(numCases);
     SmallVector<Block *> caseBlocks(numCases);
     for (const auto &it : llvm::enumerate(swInst->cases())) {
       const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
@@ -1356,7 +1356,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       if (failed(convertBranchArgs(swInst, succBB, caseOperands[it.index()])))
         return failure();
       caseOperandRefs[it.index()] = caseOperands[it.index()];
-      caseValues[it.index()] = caseHandle.getCaseValue()->getSExtValue();
+      caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
       caseBlocks[it.index()] = lookupBlock(succBB);
     }
 

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index db8b304182805..787889f413b5c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -872,6 +872,17 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
 
 // -----
 
+func.func @switch_case_type_mismatch(%arg0 : i64) { // i64 cond, i32 cases
+  // expected-error@below {{expects case value type to match condition value type}}
+  "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array<i32: 0>, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array<i32: 1, 0, 0>}> : (i64) -> ()
+^bb1: // pred: ^bb0
+  llvm.return
+^bb2: // pred: ^bb0
+  llvm.return
+}
+
+// -----
+
 // expected-error@below {{expected zero value for 'common' linkage}}
 llvm.mlir.global common @non_zero_global_common_linkage(42 : i32) : i32
 

diff  --git a/mlir/test/Target/LLVMIR/Import/control-flow.ll b/mlir/test/Target/LLVMIR/Import/control-flow.ll
index 0cd7c418748c1..6c5d043c84acc 100644
--- a/mlir/test/Target/LLVMIR/Import/control-flow.ll
+++ b/mlir/test/Target/LLVMIR/Import/control-flow.ll
@@ -46,33 +46,33 @@ bb2:
 
 ; CHECK-LABEL: @simple_switch(
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
-define i32 @simple_switch(i32 %arg1) {
+define i64 @simple_switch(i64 %arg1) { ; Non-i32 condition and case values.
   ; CHECK: %[[VAL1:.+]] = llvm.add
   ; CHECK: %[[VAL2:.+]] = llvm.sub
   ; CHECK: %[[VAL3:.+]] = llvm.mul
-  %1 = add i32 %arg1, 42
-  %2 = sub i32 %arg1, 42
-  %3 = mul i32 %arg1, 42
-  ; CHECK: llvm.switch %[[ARG1]] : i32, ^[[BBD:.+]] [
+  %1 = add i64 %arg1, 42
+  %2 = sub i64 %arg1, 42
+  %3 = mul i64 %arg1, 42
+  ; CHECK: llvm.switch %[[ARG1]] : i64, ^[[BBD:.+]] [
   ; CHECK:   0: ^[[BB1:.+]],
   ; CHECK:   9: ^[[BB2:.+]]
   ; CHECK: ]
-  switch i32 %arg1, label %bbd [
-    i32 0, label %bb1
-    i32 9, label %bb2
+  switch i64 %arg1, label %bbd [
+    i64 0, label %bb1
+    i64 9, label %bb2
   ]
 bb1:
   ; CHECK: ^[[BB1]]:
   ; CHECK: llvm.return %[[VAL1]]
-  ret i32 %1
+  ret i64 %1
 bb2:
   ; CHECK: ^[[BB2]]:
   ; CHECK: llvm.return %[[VAL2]]
-  ret i32 %2
+  ret i64 %2
 bbd:
   ; CHECK: ^[[BBD]]:
   ; CHECK: llvm.return %[[VAL3]]
-  ret i32 %3
+  ret i64 %3
 }
 
 ; // -----


        


More information about the Mlir-commits mailing list