[Mlir-commits] [mlir] 88f07a3 - [mlir][llvm] Fix import of SwitchOp
Tobias Gysi
llvmlistbot at llvm.org
Wed Jun 21 23:41:27 PDT 2023
Author: Tobias Gysi
Date: 2023-06-22T06:27:32Z
New Revision: 88f07a311947f88de82ad2de9b2d6a26eba21343
URL: https://github.com/llvm/llvm-project/commit/88f07a311947f88de82ad2de9b2d6a26eba21343
DIFF: https://github.com/llvm/llvm-project/commit/88f07a311947f88de82ad2de9b2d6a26eba21343.diff
LOG: [mlir][llvm] Fix import of SwitchOp
This revision ensures that SwitchOps whose condition and case values have a
bitwidth other than 32 bits are imported properly. Previously, the importer
narrowed every case value to int32_t and always built an i32 case-value
attribute. It adds an APInt-based builder to SwitchOp and implements a
verifier that checks that the condition and case value types match.
Reviewed By: Dinistro
Differential Revision: https://reviews.llvm.org/D153438
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/Import/control-flow.ll
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4081091a579d..a106592c9b5a2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -931,7 +931,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
AnyInteger:$value,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
- OptionalAttr<ElementsAttr>:$case_values,
+ OptionalAttr<AnyIntElementsAttr>:$case_values,
DenseI32ArrayAttr:$case_operand_segments,
OptionalAttr<ElementsAttr>:$branch_weights
);
@@ -950,6 +950,13 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
let hasVerifier = 1;
let builders = [
+ OpBuilder<(ins "Value":$value,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
OpBuilder<(ins "Value":$value,
"Block *":$defaultDestination,
"ValueRange":$defaultOperands,
@@ -957,6 +964,13 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
CArg<"BlockRange", "{}">:$caseDestinations,
CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
+ OpBuilder<(ins "Value":$value,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
LLVM_TerminatorPassthroughOpBuilder
];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a97632c222287..7bbdb67d8c4a7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -355,25 +355,54 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result,
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
Block *defaultDestination, ValueRange defaultOperands,
- ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+ DenseIntElementsAttr caseValues,
+ BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands,
ArrayRef<int32_t> branchWeights) {
- ElementsAttr caseValuesAttr;
- if (!caseValues.empty())
- caseValuesAttr = builder.getI32VectorAttr(caseValues);
-
ElementsAttr weightsAttr;
if (!branchWeights.empty())
weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
- build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
+ build(builder, result, value, defaultOperands, caseOperands, caseValues,
weightsAttr, defaultDestination, caseDestinations);
}
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // Pack the cases into a 1-D vector typed like the condition (null if none).
+  DenseIntElementsAttr packedCases;
+  if (!caseValues.empty()) {
+    auto numCases = static_cast<int64_t>(caseValues.size());
+    ShapedType packedType = VectorType::get(numCases, value.getType());
+    packedCases = DenseIntElementsAttr::get(packedType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        packedCases, caseDestinations, caseOperands, branchWeights);
+}
+
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // Re-encode at the condition's bitwidth: raw int32_t data asserts unless i32.
+  SmallVector<APInt> apCaseValues;
+  for (int32_t caseValue : caseValues) {
+    unsigned width = value.getType().getIntOrFloatBitWidth();
+    apCaseValues.push_back(
+        APInt(32, caseValue, /*isSigned=*/true).sextOrTrunc(width));
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        apCaseValues, caseDestinations, caseOperands, branchWeights);
+}
+
/// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
static ParseResult parseSwitchOpCases(
- OpAsmParser &parser, Type flagType, ElementsAttr &caseValues,
+ OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
@@ -412,7 +441,7 @@ static ParseResult parseSwitchOpCases(
}
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
- ElementsAttr caseValues,
+ DenseIntElementsAttr caseValues,
SuccessorRange caseDestinations,
OperandRangeRange caseOperands,
const TypeRangeRange &caseOperandTypes) {
@@ -421,7 +450,7 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
size_t index = 0;
llvm::interleave(
- llvm::zip(llvm::cast<DenseIntElementsAttr>(caseValues), caseDestinations),
+ llvm::zip(caseValues, caseDestinations),
[&](auto i) {
p << " ";
p << std::get<0>(i).getLimitedValue();
@@ -446,6 +475,9 @@ LogicalResult SwitchOp::verify() {
return emitError("expects number of branch weights to match number of "
"successors: ")
<< getBranchWeights()->size() << " vs " << getNumSuccessors();
+ if (getCaseValues() &&
+ getValue().getType() != getCaseValues()->getElementType())
+ return emitError("expects case value type to match condition value type");
return success();
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index fbb20bed6e8ff..e8ffee33ea3bc 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1348,7 +1348,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unsigned numCases = swInst->getNumCases();
SmallVector<SmallVector<Value>> caseOperands(numCases);
SmallVector<ValueRange> caseOperandRefs(numCases);
- SmallVector<int32_t> caseValues(numCases);
+ SmallVector<APInt> caseValues(numCases);
SmallVector<Block *> caseBlocks(numCases);
for (const auto &it : llvm::enumerate(swInst->cases())) {
const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
@@ -1356,7 +1356,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertBranchArgs(swInst, succBB, caseOperands[it.index()])))
return failure();
caseOperandRefs[it.index()] = caseOperands[it.index()];
- caseValues[it.index()] = caseHandle.getCaseValue()->getSExtValue();
+ caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
caseBlocks[it.index()] = lookupBlock(succBB);
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index db8b304182805..787889f413b5c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -872,6 +872,17 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
// -----
+func.func @switch_case_type_mismatch(%arg0 : i64) {
+  // expected-error @below {{expects case value type to match condition value type}}
+  "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array<i32: 0>, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array<i32: 1, 0, 0>}> : (i64) -> ()
+^bb1: // pred: ^bb0
+  llvm.return
+^bb2: // pred: ^bb0
+  llvm.return
+}
+
+// -----
+
// expected-error @below {{expected zero value for 'common' linkage}}
llvm.mlir.global common @non_zero_global_common_linkage(42 : i32) : i32
diff --git a/mlir/test/Target/LLVMIR/Import/control-flow.ll b/mlir/test/Target/LLVMIR/Import/control-flow.ll
index 0cd7c418748c1..6c5d043c84acc 100644
--- a/mlir/test/Target/LLVMIR/Import/control-flow.ll
+++ b/mlir/test/Target/LLVMIR/Import/control-flow.ll
@@ -46,33 +46,33 @@ bb2:
; CHECK-LABEL: @simple_switch(
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-define i32 @simple_switch(i32 %arg1) {
+define i64 @simple_switch(i64 %arg1) {
; CHECK: %[[VAL1:.+]] = llvm.add
; CHECK: %[[VAL2:.+]] = llvm.sub
; CHECK: %[[VAL3:.+]] = llvm.mul
- %1 = add i32 %arg1, 42
- %2 = sub i32 %arg1, 42
- %3 = mul i32 %arg1, 42
- ; CHECK: llvm.switch %[[ARG1]] : i32, ^[[BBD:.+]] [
+ %1 = add i64 %arg1, 42
+ %2 = sub i64 %arg1, 42
+ %3 = mul i64 %arg1, 42
+ ; CHECK: llvm.switch %[[ARG1]] : i64, ^[[BBD:.+]] [
; CHECK: 0: ^[[BB1:.+]],
; CHECK: 9: ^[[BB2:.+]]
; CHECK: ]
- switch i32 %arg1, label %bbd [
- i32 0, label %bb1
- i32 9, label %bb2
+ switch i64 %arg1, label %bbd [
+ i64 0, label %bb1
+ i64 9, label %bb2
]
bb1:
; CHECK: ^[[BB1]]:
; CHECK: llvm.return %[[VAL1]]
- ret i32 %1
+ ret i64 %1
bb2:
; CHECK: ^[[BB2]]:
; CHECK: llvm.return %[[VAL2]]
- ret i32 %2
+ ret i64 %2
bbd:
; CHECK: ^[[BBD]]:
; CHECK: llvm.return %[[VAL3]]
- ret i32 %3
+ ret i64 %3
}
; // -----
More information about the Mlir-commits
mailing list