[flang-commits] [flang] 3ae8e44 - [fir] Add fir.insert_on_range conversion

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Nov 4 02:36:26 PDT 2021


Author: Valentin Clement
Date: 2021-11-04T10:36:20+01:00
New Revision: 3ae8e44215e4cf8272597ae39872229eb8934ebf

URL: https://github.com/llvm/llvm-project/commit/3ae8e44215e4cf8272597ae39872229eb8934ebf
DIFF: https://github.com/llvm/llvm-project/commit/3ae8e44215e4cf8272597ae39872229eb8934ebf.diff

LOG: [fir] Add fir.insert_on_range conversion

Convert fir.insert_on_range operation to corresponding
llvm.insertvalue operations.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D112896

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Fir/convert-to-llvm.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d7b0f819e9da4..fb9e390a0f7a2 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/ArrayRef.h"
 
@@ -45,6 +46,25 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
   }
 };
 
+/// Base class for FIR conversion patterns that need the converted result type.
+/// `matchAndRewrite` converts the result type of the matched op with the
+/// pattern's type converter and forwards it, together with the op and its
+/// adaptor, to the `doRewrite` hook implemented by the concrete pattern.
+template <typename FromOp>
+class FIROpAndTypeConversion : public FIROpConversion<FromOp> {
+public:
+  using FIROpConversion<FromOp>::FIROpConversion;
+  using OpAdaptor = typename FromOp::Adaptor;
+
+  /// Hook for the concrete lowering. `ty` is the already-converted type of
+  /// `op.getType()`; `adaptor` gives access to the remapped operands.
+  virtual mlir::LogicalResult
+  doRewrite(FromOp op, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const = 0;
+
+  mlir::LogicalResult
+  matchAndRewrite(FromOp op, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const final {
+    return doRewrite(op, this->convertType(op.getType()), adaptor, rewriter);
+  }
+};
+
 // Lower `fir.address_of` operation to `llvm.address_of` operation.
 struct AddrOfOpConversion : public FIROpConversion<fir::AddrOfOp> {
   using FIROpConversion::FIROpConversion;
@@ -204,6 +224,82 @@ struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
   }
 };
 
+/// InsertOnRange inserts a value into a sequence over a range of offsets.
+///
+/// The op's `coor` attribute holds (lower, upper) bound pairs, one pair per
+/// array dimension, in FIR dimension order; the pairs are reversed here to
+/// match the row-major nesting of the converted LLVM array type. The bounds
+/// are inclusive and delimit a linear range: every element from the lower
+/// coordinate up to and including the upper coordinate, walked in row-major
+/// order over the full extents, receives the inserted value. One
+/// `llvm.insertvalue` is emitted per element, each chained on the previous one.
+struct InsertOnRangeOpConversion
+    : public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  // Increments an array of subscripts in a row major fashion: the last
+  // subscript varies fastest and carries into the previous one when it
+  // reaches its dimension extent. `dims` and `subscripts` must have the same
+  // size.
+  void incrementSubscripts(const SmallVector<uint64_t> &dims,
+                           SmallVector<uint64_t> &subscripts) const {
+    assert(dims.size() == subscripts.size() &&
+           "subscripts must match the number of array dimensions");
+    for (size_t i = dims.size(); i > 0; --i) {
+      if (++subscripts[i - 1] < dims[i - 1]) {
+        return;
+      }
+      subscripts[i - 1] = 0;
+    }
+  }
+
+  mlir::LogicalResult
+  doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+
+    // Extract integer values from the coordinate attribute.
+    SmallVector<int64_t> coordinates = llvm::to_vector<4>(
+        llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
+          return a.cast<IntegerAttr>().getInt();
+        }));
+    // The unzip loop below consumes coordinates two at a time; an odd count
+    // would make it dereference past the end.
+    if (coordinates.size() % 2 != 0)
+      return rewriter.notifyMatchFailure(
+          range, "expected (lower, upper) coordinate pairs");
+    const size_t rank = coordinates.size() / 2;
+
+    // Extract the extents of the `rank` outermost array dimensions from the
+    // converted type. Stop at `rank`: the element type may itself lower to an
+    // LLVM array (e.g. a character element), and those inner dimensions are
+    // not addressed by the range. Walking into them would make `dims` longer
+    // than the subscripts and index out of bounds in incrementSubscripts.
+    llvm::SmallVector<uint64_t> dims;
+    mlir::Type type = adaptor.getOperands()[0].getType();
+    while (dims.size() < rank) {
+      auto t = type.dyn_cast<mlir::LLVM::LLVMArrayType>();
+      if (!t)
+        return rewriter.notifyMatchFailure(
+            range, "converted sequence type has fewer dimensions than range");
+      dims.push_back(t.getNumElements());
+      type = t.getElementType();
+    }
+
+    // Unzip the upper and lower bound and convert to a row major format.
+    SmallVector<uint64_t> lBounds;
+    SmallVector<uint64_t> uBounds;
+    for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+      uBounds.push_back(*i++);
+      lBounds.push_back(*i);
+    }
+
+    // Builds the position attribute of an llvm.insertvalue from subscripts.
+    auto i64Ty = rewriter.getI64Type();
+    auto toPositionAttr = [&](llvm::ArrayRef<uint64_t> subs) {
+      SmallVector<mlir::Attribute> attrs;
+      for (uint64_t sub : subs)
+        attrs.push_back(IntegerAttr::get(i64Ty, sub));
+      return ArrayAttr::get(range.getContext(), attrs);
+    };
+
+    auto &subscripts = lBounds;
+    auto loc = range.getLoc();
+    mlir::Value lastOp = adaptor.getOperands()[0];
+    mlir::Value insertVal = adaptor.getOperands()[1];
+
+    // Insert at every position strictly before the upper bound...
+    while (subscripts != uBounds) {
+      lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
+          loc, ty, lastOp, insertVal, toPositionAttr(subscripts));
+      incrementSubscripts(dims, subscripts);
+    }
+
+    // ...and let the insertion at the (inclusive) upper bound replace the op.
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+        range, ty, lastOp, insertVal, toPositionAttr(subscripts));
+
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -221,10 +317,9 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::OwningRewritePatternList pattern(context);
-    pattern
-        .insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
-            typeConverter);
+    pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
+                   InsertOnRangeOpConversion, UndefOpConversion,
+                   UnreachableOpConversion, ZeroOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 3ef31ac28818d..a977dac869eab 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -84,6 +84,28 @@ fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
 
 // -----
 
+// Test global with insert_on_range operation not covering the full array
+// in initializer region. The range [5, 31] is inclusive, so 27 insertions are
+// expected: one at [5], exactly 25 in between ([6]..[30]) and one at [31].
+
+fir.global internal @_QEmultiarray : !fir.array<32xi32> {
+  %c1_i32 = arith.constant 1 : i32
+  %0 = fir.undefined !fir.array<32xi32>
+  %2 = fir.insert_on_range %0, %c1_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+  fir.has_value %2 : !fir.array<32xi32>
+}
+
+// CHECK:          llvm.mlir.global internal @_QEmultiarray() : !llvm.array<32 x i32> {
+// CHECK:            %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:            %{{.*}} = llvm.mlir.undef : !llvm.array<32 x i32>
+// CHECK:            %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[5] : !llvm.array<32 x i32>
+// CHECK-COUNT-25:   %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[{{.*}}] : !llvm.array<32 x i32>
+// CHECK:            %{{.*}} = llvm.insertvalue %[[CST]], %{{.*}}[31] : !llvm.array<32 x i32>
+// CHECK-NOT:        llvm.insertvalue
+// CHECK:            llvm.return %{{.*}} : !llvm.array<32 x i32>
+// CHECK:          }
+
+// -----
+
 // Test fir.zero_bits operation with LLVM ptr type
 
 func @zero_test_ptr() {


        


More information about the flang-commits mailing list