[flang-commits] [flang] 8c23990 - [fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Nov 5 04:54:59 PDT 2021


Author: Valentin Clement
Date: 2021-11-05T12:54:51+01:00
New Revision: 8c239909495eafac9f9f765af70bbea660af6d6c

URL: https://github.com/llvm/llvm-project/commit/8c239909495eafac9f9f765af70bbea660af6d6c
DIFF: https://github.com/llvm/llvm-project/commit/8c239909495eafac9f9f765af70bbea660af6d6c.diff

LOG: [fir] Add fir.select and fir.select_rank FIR to LLVM IR conversion patterns

The `fir.select` and `fir.select_rank` operations are lowered to `llvm.switch`.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D113089

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 1a6cd9b7f00fc..6d6f79c2d7adf 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -480,7 +480,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
 
     // The number of destination conditions that may be tested
     unsigned getNumConditions() {
-      return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).size();
+      return getCases().size();
     }
 
     // The selector is the value being tested to determine the destination
@@ -488,6 +488,9 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
     mlir::Value getSelector(llvm::ArrayRef<mlir::Value> operands) {
       return operands[0];
     }
+    mlir::Value getSelector(mlir::ValueRange operands) {
+      return operands.front();
+    }
 
     // The number of blocks that may be branched to
     unsigned getNumDest() { return (*this)->getNumSuccessors(); }
@@ -498,6 +501,8 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
 
     llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
         llvm::ArrayRef<mlir::Value> operands, unsigned cond);
+    llvm::Optional<mlir::ValueRange> getSuccessorOperands(
+        mlir::ValueRange operands, unsigned cond);
     using BranchOpInterfaceTrait::getSuccessorOperands;
 
     // Helper function to deal with Optional operand forms
@@ -510,6 +515,10 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
         p.printSuccessor(succ);
     }
 
+    mlir::ArrayAttr getCases() {
+      return (*this)->getAttrOfType<mlir::ArrayAttr>(getCasesAttr());
+    }
+
     unsigned targetOffsetSize();
   }];
 }

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index fb9e390a0f7a2..e18bdc2bdbf4b 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -174,6 +174,78 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
   }
 };
 
+/// Lower a FIR select-like terminator (type OP) to an LLVM::SwitchOp.
+template <typename OP>
+void selectMatchAndRewrite(fir::LLVMTypeConverter &lowering, OP select,
+                           typename OP::Adaptor adaptor,
+                           mlir::ConversionPatternRewriter &rewriter) {
+  unsigned conds = select.getNumConditions();
+  auto cases = select.getCases().getValue();
+  mlir::Value selector = adaptor.selector();
+  auto loc = select.getLoc();
+  assert(conds > 0 && "select must have cases");
+
+  llvm::SmallVector<mlir::Block *> destinations;
+  llvm::SmallVector<mlir::ValueRange> destinationsOperands;
+  mlir::Block *defaultDestination = nullptr; // set by the trailing unit case
+  mlir::ValueRange defaultOperands;
+  llvm::SmallVector<int32_t> caseValues;
+  for (unsigned t = 0; t != conds; ++t) {
+    mlir::Block *dest = select.getSuccessor(t);
+    auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
+    const mlir::Attribute &attr = cases[t];
+    if (auto intAttr = attr.template dyn_cast<mlir::IntegerAttr>()) {
+      destinations.push_back(dest);
+      destinationsOperands.push_back(destOps.hasValue() ? *destOps
+                                                        : ValueRange());
+      caseValues.push_back(intAttr.getInt());
+      continue;
+    }
+    assert(attr.template dyn_cast_or_null<mlir::UnitAttr>());
+    assert((t + 1 == conds) && "unit must be last");
+    defaultDestination = dest;
+    defaultOperands = destOps.hasValue() ? *destOps : ValueRange();
+  }
+  assert(defaultDestination && "select must have a unit (default) case");
+
+  // LLVM::SwitchOp takes an i32 type for the selector.
+  if (select.getSelector().getType() != rewriter.getI32Type())
+    selector =
+        rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), selector);
+  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
+      select, selector,
+      /*defaultDestination=*/defaultDestination,
+      /*defaultOperands=*/defaultOperands,
+      /*caseValues=*/caseValues,
+      /*caseDestinations=*/destinations,
+      /*caseOperands=*/destinationsOperands,
+      /*branchWeights=*/ArrayRef<int32_t>());
+}
+
+/// Lower fir::SelectOp to mlir::LLVM::SwitchOp via selectMatchAndRewrite.
+struct SelectOpConversion : public FIROpConversion<fir::SelectOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
+    return success(); // the helper returns void, so no failure is reported
+  }
+};
+
+/// Lower fir::SelectRankOp to mlir::LLVM::SwitchOp via selectMatchAndRewrite.
+struct SelectRankOpConversion : public FIROpConversion<fir::SelectRankOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
+    return success(); // the helper returns void, so no failure is reported
+  }
+};
+
 // convert to LLVM IR dialect `undef`
 struct UndefOpConversion : public FIROpConversion<fir::UndefOp> {
   using FIROpConversion::FIROpConversion;
@@ -318,8 +390,9 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::OwningRewritePatternList pattern(context);
     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                   InsertOnRangeOpConversion, UndefOpConversion,
-                   UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+                   InsertOnRangeOpConversion, SelectOpConversion,
+                   SelectRankOpConversion, UnreachableOpConversion,
+                   ZeroOpConversion, UndefOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index bdeced9c7b617..62d04b30c6941 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2264,6 +2264,15 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
+llvm::Optional<mlir::ValueRange> // operands for the successor `oper`
+fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) {
+  auto a = // attribute named by getTargetOffsetAttr()
+      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
+  auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
+      getOperandSegmentSizeAttr()); // operand segment sizes attribute
+  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
+}
+
 unsigned fir::SelectOp::targetOffsetSize() {
   return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
       getTargetOffsetAttr()));
@@ -2557,6 +2566,16 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
+llvm::Optional<mlir::ValueRange> // operands for the successor `oper`
+fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands,
+                                        unsigned oper) {
+  auto a = // attribute named by getTargetOffsetAttr()
+      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
+  auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
+      getOperandSegmentSizeAttr()); // operand segment sizes attribute
+  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
+}
+
 unsigned fir::SelectRankOp::targetOffsetSize() {
   return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
       getTargetOffsetAttr()));

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a977dac869eab..b0a628678b562 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -167,3 +167,95 @@ func @zero_test_float() {
 func @test_unreachable() {
   fir.unreachable
 }
+
+// -----
+
+// Test `fir.select` operation conversion pattern.
+// Check that the `llvm.switch` is correctly constructed and that each case
+// branches to the correct block with the expected operands.
+
+func @select(%arg : index, %arg2 : i32) -> i32 {
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 2 : i32
+  %2 = arith.constant 3 : i32
+  %3 = arith.constant 4 : i32
+  fir.select %arg:index [ 1, ^bb1(%0:i32),
+                          2, ^bb2(%2,%arg,%arg2:i32,index,i32),
+                          3, ^bb3(%arg2,%2:i32,i32),
+                          4, ^bb4(%1:i32),
+                          unit, ^bb5 ]
+  ^bb1(%a : i32) :
+    return %a : i32
+  ^bb2(%b : i32, %b2 : index, %b3:i32) :
+    %castidx = arith.index_cast %b2 : index to i32
+    %4 = arith.addi %b, %castidx : i32
+    %5 = arith.addi %4, %b3 : i32
+    return %5 : i32
+  ^bb3(%c:i32, %c2:i32) :
+    %6 = arith.addi %c, %c2 : i32
+    return %6 : i32
+  ^bb4(%d : i32) :
+    return %d : i32
+  ^bb5 :
+    %zero = arith.constant 0 : i32
+    return %zero : i32
+}
+
+// CHECK-LABEL: func @select(
+// CHECK-SAME:               %[[SELECTVALUE:.*]]: [[IDX:.*]],
+// CHECK-SAME:               %[[ARG1:.*]]: i32)
+// CHECK:         %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK:         %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32
+// CHECK:         llvm.switch %[[SELECTOR]], ^bb5 [
+// CHECK:           1: ^bb1(%[[C0]] : i32),
+// CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32),
+// CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
+// CHECK:           4: ^bb4(%[[C1]] : i32)
+// CHECK:         ]
+
+// -----
+
+// Test `fir.select_rank` operation conversion pattern.
+// Check that the `llvm.switch` is correctly constructed and that each case
+// branches to the correct block with the expected operands.
+
+func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
+  %0 = arith.constant 1 : i32
+  %1 = arith.constant 2 : i32
+  %2 = arith.constant 3 : i32
+  %3 = arith.constant 4 : i32
+  fir.select_rank %arg:i32 [ 1, ^bb1(%0:i32),
+                             2, ^bb2(%2,%arg,%arg2:i32,i32,i32),
+                             3, ^bb3(%arg2,%2:i32,i32),
+                             4, ^bb4(%1:i32),
+                             unit, ^bb5 ]
+  ^bb1(%a : i32) :
+    return %a : i32
+  ^bb2(%b : i32, %b2 : i32, %b3:i32) :
+    %4 = arith.addi %b, %b2 : i32
+    %5 = arith.addi %4, %b3 : i32
+    return %5 : i32
+  ^bb3(%c:i32, %c2:i32) :
+    %6 = arith.addi %c, %c2 : i32
+    return %6 : i32
+  ^bb4(%d : i32) :
+    return %d : i32
+  ^bb5 :
+    %zero = arith.constant 0 : i32
+    return %zero : i32
+}
+
+// CHECK-LABEL: func @select_rank(
+// CHECK-SAME:                    %[[SELECTVALUE:.*]]: i32,
+// CHECK-SAME:                    %[[ARG1:.*]]: i32)
+// CHECK:         %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:         %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK:         %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK:         llvm.switch %[[SELECTVALUE]], ^bb5 [
+// CHECK:           1: ^bb1(%[[C0]] : i32),
+// CHECK:           2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32),
+// CHECK:           3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
+// CHECK:           4: ^bb4(%[[C1]] : i32)
+// CHECK:         ]


        


More information about the flang-commits mailing list