[flang-commits] [flang] 319f022 - [flang] hlfir.associate and hlfir.end_associate codegen

Jean Perier via flang-commits flang-commits at lists.llvm.org
Thu Dec 1 08:59:29 PST 2022


Author: Jean Perier
Date: 2022-12-01T17:58:28+01:00
New Revision: 319f0221eef90734a7b5fc719df07e3d5c693105

URL: https://github.com/llvm/llvm-project/commit/319f0221eef90734a7b5fc719df07e3d5c693105
DIFF: https://github.com/llvm/llvm-project/commit/319f0221eef90734a7b5fc719df07e3d5c693105.diff

LOG: [flang] hlfir.associate and hlfir.end_associate codegen

Add hlfir.associate and hlfir.end_associate codegen.
To properly allow reusing the bufferized expression storage for the
newly created variable, bufferization of hlfir.expr has to be updated
so that each hlfir.expr is translated to a variable and a boolean to
indicate if the variable storage needs to be freed after the expression
was used. That way the responsibility to free the bufferized expression
can be passed to the variable user, and applied in the
hlfir.end_associate.

Right now, none of the bufferized expressions are heap allocated, so
generating the conditional freemem in hlfir.end_associate is left as
a TODO for when it can be tested.

Differential Revision: https://reviews.llvm.org/D139020

Added: 
    flang/test/HLFIR/associate-codegen.fir

Modified: 
    flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 4a8380c2bbbd8..f5a70c8112a1a 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -7,6 +7,9 @@
 //===----------------------------------------------------------------------===//
 // This file defines a pass that bufferize hlfir.expr. It translates operations
 // producing or consuming hlfir.expr into operations operating on memory.
+// An hlfir.expr is translated to a tuple<variable address, cleanupflag>
+// where cleanupflag is set to true if storage for the expression was allocated
+// on the heap.
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/Character.h"
@@ -33,6 +36,64 @@ namespace hlfir {
 
 namespace {
 
+/// Helper to create tuple from a bufferized expr storage and clean up
+/// instruction flag.
+static mlir::Value packageBufferizedExpr(mlir::Location loc,
+                                         fir::FirOpBuilder &builder,
+                                         mlir::Value storage,
+                                         mlir::Value mustFree) {
+  auto tupleType = mlir::TupleType::get(
+      builder.getContext(),
+      mlir::TypeRange{storage.getType(), mustFree.getType()});
+  auto undef = builder.create<fir::UndefOp>(loc, tupleType);
+  auto insert = builder.create<fir::InsertValueOp>(
+      loc, tupleType, undef, mustFree,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 1)}));
+  return builder.create<fir::InsertValueOp>(
+      loc, tupleType, insert, storage,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 0)}));
+}
+
+/// Helper to create tuple from a bufferized expr storage and constant
+/// boolean clean-up flag.
+static mlir::Value packageBufferizedExpr(mlir::Location loc,
+                                         fir::FirOpBuilder &builder,
+                                         mlir::Value storage, bool mustFree) {
+  mlir::Value mustFreeValue = builder.createBool(loc, mustFree);
+  return packageBufferizedExpr(loc, builder, storage, mustFreeValue);
+}
+
+/// Helper to extract the storage from a tuple created by packageBufferizedExpr.
+/// It assumes no tuples are used as HLFIR operation operands, which is
+/// currently enforced by the verifiers that only accept HLFIR value or
+/// variable types which do not include tuples.
+static mlir::Value getBufferizedExprStorage(mlir::Value bufferizedExpr) {
+  auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
+  if (!tupleType)
+    return bufferizedExpr;
+  assert(tupleType.size() == 2 && "unexpected tuple type");
+  if (auto insert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>())
+    if (insert.getVal().getType() == tupleType.getType(0))
+      return insert.getVal();
+  TODO(bufferizedExpr.getLoc(), "general extract storage case");
+}
+
+/// Helper to extract the clean-up flag from a tuple created by
+/// packageBufferizedExpr.
+static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) {
+  auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
+  if (!tupleType)
+    return bufferizedExpr;
+  assert(tupleType.size() == 2 && "unexpected tuple type");
+  if (auto insert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>())
+    if (auto insert0 = insert.getAdt().getDefiningOp<fir::InsertValueOp>())
+      if (insert0.getVal().getType() == tupleType.getType(1))
+        return insert0.getVal();
+  TODO(bufferizedExpr.getLoc(), "general extract storage case");
+}
+
 struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
   using mlir::OpConversionPattern<hlfir::AssignOp>::OpConversionPattern;
   explicit AssignOpConversion(mlir::MLIRContext *ctx)
@@ -41,7 +102,8 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
   matchAndRewrite(hlfir::AssignOp assign, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<hlfir::AssignOp>(
-        assign, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+        assign, getBufferizedExprStorage(adaptor.getOperands()[0]),
+        getBufferizedExprStorage(adaptor.getOperands()[1]));
     return mlir::success();
   }
 };
@@ -61,36 +123,106 @@ struct ConcatOpConversion : public mlir::OpConversionPattern<hlfir::ConcatOp> {
     if (adaptor.getStrings().size() > 2)
       TODO(loc, "codegen of optimized chained concatenation of more than two "
                 "strings");
-    hlfir::Entity lhs{adaptor.getStrings()[0]};
-    hlfir::Entity rhs{adaptor.getStrings()[1]};
+    hlfir::Entity lhs{getBufferizedExprStorage(adaptor.getStrings()[0])};
+    hlfir::Entity rhs{getBufferizedExprStorage(adaptor.getStrings()[1])};
     auto [lhsExv, c1] = hlfir::translateToExtendedValue(loc, builder, lhs);
     auto [rhsExv, c2] = hlfir::translateToExtendedValue(loc, builder, rhs);
     assert(!c1 && !c2 && "expected variables");
     fir::ExtendedValue res =
         fir::factory::CharacterExprHelper{builder, loc}.createConcatenate(
             *lhsExv.getCharBox(), *rhsExv.getCharBox());
+    // Ensure the memory type is the same as the result type.
+    mlir::Type addrType = fir::ReferenceType::get(
+        hlfir::getFortranElementType(concat.getResult().getType()));
+    mlir::Value cast = builder.createConvert(loc, addrType, fir::getBase(res));
+    res = fir::substBase(res, cast);
     auto hlfirTempRes = hlfir::genDeclare(loc, builder, res, "tmp",
                                           fir::FortranVariableFlagsAttr{});
-    rewriter.replaceOp(concat, hlfirTempRes);
+    mlir::Value bufferizedExpr =
+        packageBufferizedExpr(loc, builder, hlfirTempRes, false);
+    rewriter.replaceOp(concat, bufferizedExpr);
     return mlir::success();
   }
 };
 
+struct AssociateOpConversion
+    : public mlir::OpConversionPattern<hlfir::AssociateOp> {
+  using mlir::OpConversionPattern<hlfir::AssociateOp>::OpConversionPattern;
+  explicit AssociateOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::AssociateOp>{ctx} {}
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = associate->getLoc();
+    // If this is the last use of the expression value and this is an hlfir.expr
+    // that was bufferized, re-use the storage.
+    // Otherwise, create a temp and assign the storage to it.
+    mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource());
+    const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType());
+
+    auto replaceWith = [&](mlir::Value hlfirVar, mlir::Value firVar,
+                           mlir::Value flag) {
+      associate.getResult(0).replaceAllUsesWith(hlfirVar);
+      associate.getResult(1).replaceAllUsesWith(firVar);
+      associate.getResult(2).replaceAllUsesWith(flag);
+      rewriter.replaceOp(associate, {hlfirVar, firVar, flag});
+    };
+
+    if (!isTrivialValue && associate.getSource().hasOneUse()) {
+      mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getSource());
+      mlir::Value firBase = hlfir::Entity{bufferizedExpr}.getFirBase();
+      replaceWith(bufferizedExpr, firBase, mustFree);
+      return mlir::success();
+    }
+    if (isTrivialValue) {
+      auto module = associate->getParentOfType<mlir::ModuleOp>();
+      fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+      auto temp = builder.createTemporary(loc, bufferizedExpr.getType(),
+                                          associate.getUniqName());
+      builder.create<fir::StoreOp>(loc, bufferizedExpr, temp);
+      mlir::Value mustFree = builder.createBool(loc, false);
+      replaceWith(temp, temp, mustFree);
+      return mlir::success();
+    }
+    TODO(loc, "hlfir.associate of hlfir.expr with more than one use");
+  }
+};
+
+struct EndAssociateOpConversion
+    : public mlir::OpConversionPattern<hlfir::EndAssociateOp> {
+  using mlir::OpConversionPattern<hlfir::EndAssociateOp>::OpConversionPattern;
+  explicit EndAssociateOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::EndAssociateOp>{ctx} {}
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Value mustFree = adaptor.getMustFree();
+    if (auto cstMustFree = fir::factory::getIntIfConstant(mustFree))
+      if (*cstMustFree == 0) {
+        rewriter.eraseOp(endAssociate);
+        return mlir::success(); // nothing to do.
+      }
+    TODO(endAssociate.getLoc(), "conditional free");
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
     // TODO: make this a pass operating on FuncOp. The issue is that
     // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
     // intrinsics calls creation. This may create race conflict if the pass is
-    // scheduleed on FuncOp. A solution could be to provide an optional mutex
+    // scheduled on FuncOp. A solution could be to provide an optional mutex
     // when building a FirOpBuilder and locking around FuncOp and GlobalOp
     // creation, but this needs a bit more thinking, so at this point the pass
     // is scheduled on the moduleOp.
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<AssignOpConversion, ConcatOpConversion>(context);
+    patterns.insert<AssignOpConversion, AssociateOpConversion,
+                    ConcatOpConversion, EndAssociateOpConversion>(context);
     mlir::ConversionTarget target(*context);
+    target.addIllegalOp<hlfir::AssociateOp, hlfir::EndAssociateOp>();
     target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
       return llvm::all_of(
                  op->getResultTypes(),

diff  --git a/flang/test/HLFIR/associate-codegen.fir b/flang/test/HLFIR/associate-codegen.fir
new file mode 100644
index 0000000000000..8eae3fd2f223a
--- /dev/null
+++ b/flang/test/HLFIR/associate-codegen.fir
@@ -0,0 +1,85 @@
+// Test hlfir.associate/hlfir.end_associate operation code generation to FIR.
+
+// RUN: fir-opt %s -bufferize-hlfir | FileCheck %s
+
+func.func @associate_int() {
+  %c42_i32 = arith.constant 42 : i32
+  %0:3 = hlfir.associate %c42_i32 {uniq_name = "x"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+  fir.call @take_i4(%0#0) : (!fir.ref<i32>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<i32>, i1
+  return
+}
+// CHECK-LABEL:   func.func @associate_int() {
+// CHECK:  %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "x"}
+// CHECK:  %[[VAL_1:.*]] = arith.constant 42 : i32
+// CHECK:  fir.store %[[VAL_1]] to %[[VAL_0]] : !fir.ref<i32>
+// CHECK:  %[[VAL_2:.*]] = arith.constant false
+// CHECK:  fir.call @take_i4(%[[VAL_0]]) : (!fir.ref<i32>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_real() {
+  %cst = arith.constant 4.200000e-01 : f32
+  %0:3 = hlfir.associate %cst {uniq_name = "x"} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+  fir.call @take_r4(%0#0) : (!fir.ref<f32>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<f32>, i1
+  return
+}
+// CHECK-LABEL:   func.func @associate_real() {
+// CHECK:  %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "x"}
+// CHECK:  %[[VAL_1:.*]] = arith.constant 4.200000e-01 : f32
+// CHECK:  fir.store %[[VAL_1]] to %[[VAL_0]] : !fir.ref<f32>
+// CHECK:  %[[VAL_2:.*]] = arith.constant false
+// CHECK:  fir.call @take_r4(%[[VAL_0]]) : (!fir.ref<f32>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_logical() {
+  %true = arith.constant true
+  %0 = fir.convert %true : (i1) -> !fir.logical<4>
+  %1:3 = hlfir.associate %0 {uniq_name = "x"} : (!fir.logical<4>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>, i1)
+  fir.call @take_l4(%1#0) : (!fir.ref<!fir.logical<4>>) -> ()
+  hlfir.end_associate %1#1, %1#2 : !fir.ref<!fir.logical<4>>, i1
+  return
+}
+// CHECK-LABEL:   func.func @associate_logical() {
+// CHECK:  %[[VAL_0:.*]] = fir.alloca !fir.logical<4> {bindc_name = "x"}
+// CHECK:  %[[VAL_1:.*]] = arith.constant true
+// CHECK:  %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (i1) -> !fir.logical<4>
+// CHECK:  fir.store %[[VAL_2]] to %[[VAL_0]] : !fir.ref<!fir.logical<4>>
+// CHECK:  %[[VAL_3:.*]] = arith.constant false
+// CHECK:  fir.call @take_l4(%[[VAL_0]]) : (!fir.ref<!fir.logical<4>>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_char(%arg0: !fir.boxchar<1> ) {
+  %0:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+  %1:2 = hlfir.declare %0#0 typeparams %0#1 {uniq_name = "x"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+  %2 = arith.addi %0#1, %0#1 : index
+  %3 = hlfir.concat %1#0, %1#0 len %2 : (!fir.boxchar<1>, !fir.boxchar<1>, index) -> !hlfir.expr<!fir.char<1,?>>
+  %4:3 = hlfir.associate %3 typeparams %2 {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>, i1)
+  fir.call @take_c(%4#0) : (!fir.boxchar<1>) -> ()
+  hlfir.end_associate %4#1, %4#2 : !fir.ref<!fir.char<1,?>>, i1
+  return
+}
+// CHECK-LABEL:   func.func @associate_char(
+// CHECK-SAME:    %[[VAL_0:.*]]: !fir.boxchar<1>) {
+// CHECK:  %[[VAL_1:.*]]:2 = fir.unboxchar %[[VAL_0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+// CHECK:  %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]]#0 typeparams %[[VAL_1]]#1 {uniq_name = "x"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+// CHECK:  %[[VAL_3:.*]] = arith.addi %[[VAL_1]]#1, %[[VAL_1]]#1 : index
+// CHECK:  %[[VAL_4:.*]] = arith.addi %[[VAL_1]]#1, %[[VAL_1]]#1 : index
+// CHECK:  %[[VAL_5:.*]] = fir.alloca !fir.char<1,?>(%[[VAL_4]] : index) {bindc_name = ".chrtmp"}
+// CHECK:  fir.call @llvm.memmove.p0.p0.i64
+// CHECK:  %[[VAL_21:.*]]:2 = hlfir.declare %[[VAL_5]] typeparams %[[VAL_4]] {uniq_name = "tmp"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+// CHECK:  %[[VAL_22:.*]] = arith.constant false
+// CHECK:  %[[VAL_23:.*]] = fir.undefined tuple<!fir.boxchar<1>, i1>
+// CHECK:  %[[VAL_24:.*]] = fir.insert_value %[[VAL_23]], %[[VAL_22]], [1 : index] : (tuple<!fir.boxchar<1>, i1>, i1) -> tuple<!fir.boxchar<1>, i1>
+// CHECK:  %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_21]]#0, [0 : index] : (tuple<!fir.boxchar<1>, i1>, !fir.boxchar<1>) -> tuple<!fir.boxchar<1>, i1>
+// CHECK:  fir.call @take_c(%[[VAL_21]]#0) : (!fir.boxchar<1>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func private @take_i4(!fir.ref<i32>)
+func.func private @take_r4(!fir.ref<f32>)
+func.func private @take_l4(!fir.ref<!fir.logical<4>>)
+func.func private @take_c(!fir.boxchar<1>)


        


More information about the flang-commits mailing list