[flang-commits] [flang] 8c23990 - [fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Fri Nov 5 04:54:59 PDT 2021
Author: Valentin Clement
Date: 2021-11-05T12:54:51+01:00
New Revision: 8c239909495eafac9f9f765af70bbea660af6d6c
URL: https://github.com/llvm/llvm-project/commit/8c239909495eafac9f9f765af70bbea660af6d6c
DIFF: https://github.com/llvm/llvm-project/commit/8c239909495eafac9f9f765af70bbea660af6d6c.diff
LOG: [fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns
The `fir.select` and `fir.select_rank` operations are lowered to `llvm.switch`.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D113089
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/test/Fir/convert-to-llvm.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 1a6cd9b7f00fc..6d6f79c2d7adf 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -480,7 +480,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
// The number of destination conditions that may be tested
unsigned getNumConditions() {
- return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).size();
+ return getCases().size();
}
// The selector is the value being tested to determine the destination
@@ -488,6 +488,9 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
return operands[0];
}
+    // ValueRange flavour of getSelector: operand #0 is the value that is
+    // tested to pick the destination.
+    mlir::Value getSelector(mlir::ValueRange operands) {
+      return operands[0];
+    }
// The number of blocks that may be branched to
unsigned getNumDest() { return (*this)->getNumSuccessors(); }
@@ -498,6 +501,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
+ // ValueRange counterpart of the ArrayRef-based getSuccessorOperands above:
+ // returns the operands forwarded to successor number `cond`, if any.
+ llvm::Optional<mlir::ValueRange> getSuccessorOperands(
+ mlir::ValueRange operands, unsigned cond);
using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
@@ -510,6 +515,10 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
p.printSuccessor(succ);
}
+    // The per-destination case attributes, stored on the op under the name
+    // returned by getCasesAttr(); getNumConditions() is its size.
+    mlir::ArrayAttr getCases() {
+      auto casesName = getCasesAttr();
+      return (*this)->getAttrOfType<mlir::ArrayAttr>(casesName);
+    }
+
unsigned targetOffsetSize();
}];
}
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index fb9e390a0f7a2..e18bdc2bdbf4b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -174,6 +174,78 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
}
};
+/// Replace a FIR switch-style terminator (`fir.select`, `fir.select_rank`)
+/// with an `llvm.switch` on an i32 selector.
+///
+/// Every integer case becomes a switch case carrying that case's successor
+/// operands; the `unit` case, which must be the last one, becomes the switch
+/// default destination.
+///
+/// \param lowering  the FIR-to-LLVM type converter (not used by this helper).
+/// \param select    the FIR terminator being replaced.
+/// \param adaptor   the converted operands of \p select.
+/// \param rewriter  the conversion rewriter used to build the replacement.
+template <typename OP>
+void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
+                           typename OP::Adaptor adaptor,
+                           mlir::ConversionPatternRewriter &rewriter) {
+  unsigned conds = select.getNumConditions();
+  auto cases = select.getCases().getValue();
+  mlir::Value selector = adaptor.selector();
+  auto loc = select.getLoc();
+  assert(conds > 0 && "select must have cases");
+
+  llvm::SmallVector<mlir::Block *> destinations;
+  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+  // Start as null so that a select without a `unit` case is detected below
+  // instead of passing an indeterminate pointer to the switch builder.
+  mlir::Block *defaultDestination = nullptr;
+  mlir::ValueRange defaultOperands;
+  llvm::SmallVector<int32_t> caseValues;
+
+  for (unsigned t = 0; t != conds; ++t) {
+    mlir::Block *dest = select.getSuccessor(t);
+    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+    mlir::ValueRange ops = destOps.hasValue() ? *destOps : mlir::ValueRange();
+    const mlir::Attribute &attr = cases[t];
+    if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
+      destinations.push_back(dest);
+      destinationsOperands.push_back(ops);
+      // NOTE(review): case values are stored as int32_t, so wider constants
+      // are narrowed here.
+      caseValues.push_back(intAttr.getInt());
+      continue;
+    }
+    assert(attr.template dyn_cast_or_null<mlir::UnitAttr>() &&
+           "non-integer case must be the unit (default) case");
+    assert((t + 1 == conds) && "unit must be last");
+    defaultDestination = dest;
+    defaultOperands = ops;
+  }
+  assert(defaultDestination && "select must have a unit (default) case");
+
+  // LLVM::SwitchOp takes an i32 selector. Decide on the *converted* selector
+  // type (what the switch actually receives): sign-extend narrower integers,
+  // matching the signed getInt() case values, since truncating to a wider
+  // type is invalid; truncate anything wider; leave an i32 selector as is.
+  mlir::Type i32Ty = rewriter.getI32Type();
+  if (selector.getType() != i32Ty) {
+    auto intTy = selector.getType().dyn_cast<mlir::IntegerType>();
+    if (intTy && intTy.getWidth() < 32)
+      selector = rewriter.create<mlir::LLVM::SExtOp>(loc, i32Ty, selector);
+    else
+      selector = rewriter.create<mlir::LLVM::TruncOp>(loc, i32Ty, selector);
+  }
+
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+      select, selector,
+      /*defaultDestination=*/defaultDestination,
+      /*defaultOperands=*/defaultOperands,
+      /*caseValues=*/caseValues,
+      /*caseDestinations=*/destinations,
+      /*caseOperands=*/destinationsOperands,
+      /*branchWeights=*/llvm::ArrayRef<int32_t>());
+}
+
+/// Conversion of fir::SelectOp to an `llvm.switch` via selectMatchAndRewrite:
+/// integer cases become switch cases and the trailing `unit` case becomes the
+/// switch default destination.
+struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
+    return success();
+  }
+};
+
+/// Conversion of fir::SelectRankOp to an `llvm.switch` via
+/// selectMatchAndRewrite: integer (rank) cases become switch cases and the
+/// trailing `unit` case becomes the switch default destination.
+struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
+    return success();
+  }
+};
+
// convert to LLVM IR dialect `undef`
struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
using FIROpConversion::FIROpConversion;
@@ -318,8 +390,9 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
fir::LLVMTypeConverter typeConverter{getModule()};
mlir::OwningRewritePatternList pattern(context);
pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
- InsertOnRangeOpConversion, UndefOpConversion,
- UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+ InsertOnRangeOpConversion, SelectOpConversion,
+ SelectRankOpConversion, UnreachableOpConversion,
+ ZeroOpConversion, UndefOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bdeced9c7b617..62d04b30c6941 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2264,6 +2264,15 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
+/// ValueRange overload: return the operands forwarded to successor `oper`.
+/// Operand segment #2 (located through the operand-segment-sizes attribute)
+/// is sliced with the target-offset attribute to isolate that successor.
+llvm::Optional<mlir::ValueRange>
+fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) {
+  auto segmentSizes = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
+      getOperandSegmentSizeAttr());
+  auto targetOffsets =
+      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
+  auto targetOperands = getSubOperands(2, operands, segmentSizes);
+  return {getSubOperands(oper, targetOperands, targetOffsets)};
+}
+
unsigned fir::SelectOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));
@@ -2557,6 +2566,16 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
+/// ValueRange overload: return the operands forwarded to successor `oper`.
+/// Operand segment #2 (located through the operand-segment-sizes attribute)
+/// is sliced with the target-offset attribute to isolate that successor.
+llvm::Optional<mlir::ValueRange>
+fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
+                                        unsigned oper) {
+  auto segmentSizes = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
+      getOperandSegmentSizeAttr());
+  auto targetOffsets =
+      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
+  auto targetOperands = getSubOperands(2, operands, segmentSizes);
+  return {getSubOperands(oper, targetOperands, targetOffsets)};
+}
+
unsigned fir::SelectRankOp::targetOffsetSize() {
return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
getTargetOffsetAttr()));
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a977dac869eab..b0a628678b562 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -167,3 +167,95 @@ func @zero_test_float() {
func @test_unreachable() {
fir.unreachable
}
+
+// -----
+
+// Test `fir.select` operation conversion pattern.
+// Check that the `llvm.switch` is correctly constructed and that each case
+// branches to the correct block with the correct operands.
+
+// The selector has type `index`, so the lowering must narrow it to i32.
+func @select(%arg : index, %arg2 : i32) -> i32 {
+ %0 = arith.constant 1 : i32
+ %1 = arith.constant 2 : i32
+ %2 = arith.constant 3 : i32
+ %3 = arith.constant 4 : i32
+ fir.select %arg:index [ 1, ^bb1(%0:i32),
+ 2, ^bb2(%2,%arg,%arg2:i32,index,i32),
+ 3, ^bb3(%arg2,%2:i32,i32),
+ 4, ^bb4(%1:i32),
+ unit, ^bb5 ]
+ ^bb1(%a : i32) :
+ return %a : i32
+ ^bb2(%b : i32, %b2 : index, %b3:i32) :
+ %castidx = arith.index_cast %b2 : index to i32
+ %4 = arith.addi %b, %castidx : i32
+ %5 = arith.addi %4, %b3 : i32
+ return %5 : i32
+ ^bb3(%c:i32, %c2:i32) :
+ %6 = arith.addi %c, %c2 : i32
+ return %6 : i32
+ ^bb4(%d : i32) :
+ return %d : i32
+ ^bb5 :
+ %zero = arith.constant 0 : i32
+ return %zero : i32
+}
+
+// CHECK-LABEL: func @select(
+// CHECK-SAME: %[[SELECTVALUE:.*]]: [[IDX:.*]],
+// CHECK-SAME: %[[ARG1:.*]]: i32)
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
+// The converted index selector is truncated to the i32 that llvm.switch uses.
+// CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
+// The `unit` successor ^bb5 becomes the switch default; each integer case
+// keeps its value, its successor and its successor operands.
+// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [
+// CHECK: 1: ^bb1(%[[C0]] : i32),
+// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
+// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
+// CHECK: 4: ^bb4(%[[C1]] : i32)
+// CHECK: ]
+
+// -----
+
+// Test `fir.select_rank` operation conversion pattern.
+// Check that the `llvm.switch` is correctly constructed and that each case
+// branches to the correct block with the correct operands.
+
+// The selector is already i32, so no conversion is expected before the switch.
+func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
+ %0 = arith.constant 1 : i32
+ %1 = arith.constant 2 : i32
+ %2 = arith.constant 3 : i32
+ %3 = arith.constant 4 : i32
+ fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32),
+ 2, ^bb2(%2,%arg,%arg2:i32,i32,i32),
+ 3, ^bb3(%arg2,%2:i32,i32),
+ 4, ^bb4(%1:i32),
+ unit, ^bb5 ]
+ ^bb1(%a : i32) :
+ return %a : i32
+ ^bb2(%b : i32, %b2 : i32, %b3:i32) :
+ %4 = arith.addi %b, %b2 : i32
+ %5 = arith.addi %4, %b3 : i32
+ return %5 : i32
+ ^bb3(%c:i32, %c2:i32) :
+ %6 = arith.addi %c, %c2 : i32
+ return %6 : i32
+ ^bb4(%d : i32) :
+ return %d : i32
+ ^bb5 :
+ %zero = arith.constant 0 : i32
+ return %zero : i32
+}
+
+// CHECK-LABEL: func @select_rank(
+// CHECK-SAME: %[[SELECTVALUE:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: i32)
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
+// The i32 argument feeds the switch directly; ^bb5 (the `unit` case) is the
+// default destination.
+// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [
+// CHECK: 1: ^bb1(%[[C0]] : i32),
+// CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
+// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
+// CHECK: 4: ^bb4(%[[C1]] : i32)
+// CHECK: ]
More information about the flang-commits
mailing list