[flang-commits] [flang] 416e503 - [flang] split character procedure arguments in target-rewrite pass

Jean Perier via flang-commits flang-commits at lists.llvm.org
Thu Jan 27 07:30:24 PST 2022


Author: Jean Perier
Date: 2022-01-27T16:29:37+01:00
New Revision: 416e503adfc16b8ae7718c4cca4cc34a6158eea0

URL: https://github.com/llvm/llvm-project/commit/416e503adfc16b8ae7718c4cca4cc34a6158eea0
DIFF: https://github.com/llvm/llvm-project/commit/416e503adfc16b8ae7718c4cca4cc34a6158eea0.diff

LOG: [flang] split character procedure arguments in target-rewrite pass

When passing a character procedure as a dummy procedure, the result
length must be passed along with the function address. This covers
the cases where the dummy procedure is declared with assumed length
inside the scope that calls it: that scope needs the length to
allocate the result on the caller side.

To be compatible with other Fortran compilers, this length must be
appended after all other arguments, just like the lengths of character
objects (fir.boxchar).
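
As a rough illustration of why the length has to travel with the
address, the C++ sketch below models a callee that receives a character
dummy procedure and a character object after target rewrite. The names
(`CharFunc`, `charProc`, `takesDummyCharProcImpl`) and the C++ rendering
of the Fortran calling convention are hypothetical; only the argument
order (procedure address, object address, procedure result length,
object length) is taken from the `takes_dummy_char_proc_impl` test added
by this patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>

// Hypothetical model of a character function: the caller allocates the
// result buffer and passes its address and length.
using CharFunc = void (*)(char *result, std::int64_t resultLen);

// Hypothetical character function that fills its result with 'x'.
void charProc(char *result, std::int64_t resultLen) {
  std::memset(result, 'x', static_cast<std::size_t>(resultLen));
}

// Hypothetical callee after target rewrite: the procedure address and the
// character object address keep their positions, and both lengths are
// appended after all other arguments, in the same order.
std::string takesDummyCharProcImpl(void (*proc)(), const char *c,
                                   std::int64_t procLen, std::int64_t cLen) {
  assert(procLen >= 0 && cLen >= 0);
  // The callee can size the result buffer only because procLen was passed.
  std::string result(static_cast<std::size_t>(procLen), ' ');
  reinterpret_cast<CharFunc>(proc)(&result[0], procLen);
  return result + std::string(c, static_cast<std::size_t>(cLen));
}
```

For example, `takesDummyCharProcImpl(reinterpret_cast<void (*)()>(&charProc), "hello", 7, 5)`
returns `"xxxxxxxhello"` in this model. Without `procLen`, the callee
would have no way to know how large a buffer to allocate for an
assumed-length result.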

A fir.boxchar cannot be used to implement this feature because it
is meant to take an object address, not a function address.

Instead, arguments like `tuple<function type, integer type> {fir.char_proc}`
will be recognized as character dummy procedures in FIR. That way,
lowering does not have to do the argument split.
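
The recognition rule in `fir::factory::isCharacterProcedureTuple` is
purely structural: a tuple with exactly two members, the first a
function type and the second an integer type (as accepted by
`fir::isa_integer`). The sketch below restates that rule over a
hypothetical, minimal type model (`Kind`, `Type`, and the `make*`
helpers are not FIR or MLIR APIs). As in
`getCharacterProcedureTupleType`, the constructed length member is a
64-bit integer.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical, minimal stand-in for MLIR types; not an MLIR API.
enum class Kind { Function, Integer, Tuple, Other };

struct Type {
  Kind kind;
  unsigned width = 0;                         // bit width, Integer only
  std::vector<std::shared_ptr<Type>> members; // Tuple only
};

using TypePtr = std::shared_ptr<Type>;

TypePtr makeFunction() { return std::make_shared<Type>(Type{Kind::Function}); }
TypePtr makeInteger(unsigned width) {
  return std::make_shared<Type>(Type{Kind::Integer, width});
}
TypePtr makeTuple(std::vector<TypePtr> members) {
  return std::make_shared<Type>(Type{Kind::Tuple, 0, std::move(members)});
}

// Mirrors the structural check: tuple<function type, integer type>.
bool isCharacterProcedureTuple(const TypePtr &ty) {
  return ty && ty->kind == Kind::Tuple && ty->members.size() == 2 &&
         ty->members[0]->kind == Kind::Function &&
         ty->members[1]->kind == Kind::Integer;
}

// Mirrors getCharacterProcedureTupleType: the length member is always i64.
TypePtr getCharacterProcedureTupleType(const TypePtr &funcPointerType) {
  return makeTuple({funcPointerType, makeInteger(64)});
}
```

Because the check ignores attributes, any `tuple<function, integer>`
argument is split whether or not it carries `fir.char_proc`. The pass
consults the attribute only when it can find the callee, to report a
tuple that is about to be split without it.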

This patch adds tools in Character.h to create this type and tuple
values, as well as to recognize them and extract their tuple members.
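
In FIR, `createCharacterProcedureTuple` converts the address and the
length to the tuple member types (via `createConvert`) before inserting
them with `fir.insert_value`, and `extractCharacterProcedureTuple` reads
members 0 and 1 back with `fir.extract_value`. A plain C++ analogue,
using `std::pair` in place of the FIR tuple and casts in place of
`fir.convert`, might look like the following; `CharProcTuple`,
`createCharProcTuple`, and `extractCharProcTuple` are illustrative names
only.

```cpp
#include <cstdint>
#include <type_traits>
#include <utility>

// Stand-in for tuple<() -> (), i64>: an opaque function address and a length.
using ProcAddr = void (*)();
using CharProcTuple = std::pair<ProcAddr, std::int64_t>;

// Analogue of createCharacterProcedureTuple: accept any function pointer and
// any integer, and convert both to the tuple member types.
template <typename Fn, typename Len>
CharProcTuple createCharProcTuple(Fn *addr, Len len) {
  static_assert(std::is_function<Fn>::value, "addr must be a function pointer");
  static_assert(std::is_integral<Len>::value, "len must be an integer");
  return {reinterpret_cast<ProcAddr>(addr), static_cast<std::int64_t>(len)};
}

// Analogue of extractCharacterProcedureTuple: return <address, length>.
std::pair<ProcAddr, std::int64_t> extractCharProcTuple(const CharProcTuple &t) {
  return {t.first, t.second};
}
```

With C++17 structured bindings, `auto [addr, len] = extractCharProcTuple(t);`
reads much like the `auto [funcPointer, len] = ...` call in the target
rewrite pass. The address must be cast back to its real function type
before it is called.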

It also updates the target rewrite pass to split these arguments like
fir.boxchar.
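
The splitting rule shared by `fir.boxchar` and character procedure
tuples can be sketched as a single pass over the argument list: the
first part of each split argument stays in place, and its length goes
to a trailing list that is appended after all other arguments. The
`ArgKind` enumeration and the string labels below are hypothetical
simplifications of the pass; the resulting order matches the CHECK
lines for `takes_dummy_char_proc_impl` in the new test.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical classification of FIR argument types for this sketch.
enum class ArgKind { Other, BoxChar, CharProcTuple };

struct Arg {
  ArgKind kind;
  std::string name;
};

// Returns the rewritten argument list: split arguments keep their first part
// in place, and every length is appended, in order, after all other arguments.
std::vector<std::string> splitArguments(const std::vector<Arg> &args) {
  std::vector<std::string> newArgs;
  std::vector<std::string> trailing;
  for (const Arg &arg : args) {
    switch (arg.kind) {
    case ArgKind::BoxChar: // !fir.boxchar<K> -> !fir.ref<...> + length
      newArgs.push_back(arg.name + ".addr");
      trailing.push_back(arg.name + ".len");
      break;
    case ArgKind::CharProcTuple: // tuple<() -> (), i64> -> () -> () + i64
      newArgs.push_back(arg.name + ".proc");
      trailing.push_back(arg.name + ".len");
      break;
    case ArgKind::Other:
      newArgs.push_back(arg.name);
      break;
    }
  }
  newArgs.insert(newArgs.end(), trailing.begin(), trailing.end());
  return newArgs;
}
```

For a procedure tuple `p` followed by a boxchar `c`, this yields
`p.proc, c.addr, p.len, c.len`, the same order as `PROC_ADDR, C_ADDR,
PROC_LEN, C_LEN` in the test.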

This patch is part of the fir-dev upstreaming. It was previously reviewed
in: https://github.com/flang-compiler/f18-llvm-project/pull/1393

Differential Revision: https://reviews.llvm.org/D118108

Added: 
    flang/test/Fir/target-rewrite-char-proc.fir

Modified: 
    flang/include/flang/Optimizer/Builder/Character.h
    flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
    flang/lib/Optimizer/Builder/Character.cpp
    flang/lib/Optimizer/CodeGen/CMakeLists.txt
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Builder/Character.h b/flang/include/flang/Optimizer/Builder/Character.h
index 610b42052f31..b215e39d4c30 100644
--- a/flang/include/flang/Optimizer/Builder/Character.h
+++ b/flang/include/flang/Optimizer/Builder/Character.h
@@ -187,6 +187,39 @@ mlir::FuncOp getLlvmMemmove(FirOpBuilder &builder);
 mlir::FuncOp getLlvmMemset(FirOpBuilder &builder);
 mlir::FuncOp getRealloc(FirOpBuilder &builder);
 
+//===----------------------------------------------------------------------===//
+// Tools to work with Character dummy procedures
+//===----------------------------------------------------------------------===//
+
+/// Create a tuple<function type, length type> type to pass character functions
+/// as arguments along with their result length. The function type set in the
+/// tuple is the one provided by \p funcPointerType.
+mlir::Type getCharacterProcedureTupleType(mlir::Type funcPointerType);
+
+/// Is this tuple type holding a character function and its result length?
+bool isCharacterProcedureTuple(mlir::Type type);
+
+/// Is \p tuple a value holding a character function address and its result
+/// length?
+inline bool isCharacterProcedureTuple(mlir::Value tuple) {
+  return isCharacterProcedureTuple(tuple.getType());
+}
+
+/// Create a tuple<addr, len> given \p addr and \p len as well as the tuple
+/// type \p tupleType. \p addr can be any function address, and \p len can be
+/// any integer. Conversions are inserted if the types of \p addr and \p len
+/// differ from the member types of \p tupleType.
+mlir::Value createCharacterProcedureTuple(fir::FirOpBuilder &builder,
+                                          mlir::Location loc,
+                                          mlir::Type tupleType,
+                                          mlir::Value addr, mlir::Value len);
+
+/// Given a tuple containing a character function address and its result length,
+/// extract the tuple into a pair of values <function address, result length>.
+std::pair<mlir::Value, mlir::Value>
+extractCharacterProcedureTuple(fir::FirOpBuilder &builder, mlir::Location loc,
+                               mlir::Value tuple);
+
 } // namespace fir::factory
 
 #endif // FORTRAN_OPTIMIZER_BUILDER_CHARACTER_H

diff  --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index 7d8aa45b0b07..574f286818c3 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -68,6 +68,12 @@ constexpr llvm::StringRef getOptionalAttrName() { return "fir.optional"; }
 /// Attribute to mark Fortran entities with the TARGET attribute.
 static constexpr llvm::StringRef getTargetAttrName() { return "fir.target"; }
 
+/// Attribute to mark that a function argument is a character dummy procedure.
+/// Character dummy procedures have special ABI constraints.
+static constexpr llvm::StringRef getCharacterProcedureDummyAttrName() {
+  return "fir.char_proc";
+}
+
 /// Tell if \p value is:
 ///   - a function argument that has attribute \p attributeName
 ///   - or, the result of fir.alloca/fir.allocamem op that has attribute \p

diff  --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp
index b306d5431300..8b84170d663a 100644
--- a/flang/lib/Optimizer/Builder/Character.cpp
+++ b/flang/lib/Optimizer/Builder/Character.cpp
@@ -725,3 +725,51 @@ mlir::Value fir::factory::CharacterExprHelper::getLength(mlir::Value memref) {
   // Length cannot be deduced from memref.
   return {};
 }
+
+std::pair<mlir::Value, mlir::Value>
+fir::factory::extractCharacterProcedureTuple(fir::FirOpBuilder &builder,
+                                             mlir::Location loc,
+                                             mlir::Value tuple) {
+  mlir::TupleType tupleType = tuple.getType().cast<mlir::TupleType>();
+  mlir::Value addr = builder.create<fir::ExtractValueOp>(
+      loc, tupleType.getType(0), tuple,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 0)}));
+  mlir::Value len = builder.create<fir::ExtractValueOp>(
+      loc, tupleType.getType(1), tuple,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 1)}));
+  return {addr, len};
+}
+
+mlir::Value fir::factory::createCharacterProcedureTuple(
+    fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type argTy,
+    mlir::Value addr, mlir::Value len) {
+  mlir::TupleType tupleType = argTy.cast<mlir::TupleType>();
+  addr = builder.createConvert(loc, tupleType.getType(0), addr);
+  len = builder.createConvert(loc, tupleType.getType(1), len);
+  mlir::Value tuple = builder.create<fir::UndefOp>(loc, tupleType);
+  tuple = builder.create<fir::InsertValueOp>(
+      loc, tupleType, tuple, addr,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 0)}));
+  tuple = builder.create<fir::InsertValueOp>(
+      loc, tupleType, tuple, len,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 1)}));
+  return tuple;
+}
+
+bool fir::factory::isCharacterProcedureTuple(mlir::Type ty) {
+  mlir::TupleType tuple = ty.dyn_cast<mlir::TupleType>();
+  return tuple && tuple.size() == 2 &&
+         tuple.getType(0).isa<mlir::FunctionType>() &&
+         fir::isa_integer(tuple.getType(1));
+}
+
+mlir::Type
+fir::factory::getCharacterProcedureTupleType(mlir::Type funcPointerType) {
+  mlir::MLIRContext *context = funcPointerType.getContext();
+  mlir::Type lenType = mlir::IntegerType::get(context, 64);
+  return mlir::TupleType::get(context, {funcPointerType, lenType});
+}

diff  --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index b6a63c8ee8b8..04016c506ebc 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -6,12 +6,14 @@ add_flang_library(FIRCodeGen
   TargetRewrite.cpp
 
   DEPENDS
+  FIRBuilder
   FIRDialect
   FIRSupport
   FIROptCodeGenPassIncGen
   CGOpsIncGen
 
   LINK_LIBS
+  FIRBuilder
   FIRDialect
   FIRSupport
   MLIROpenMPToLLVM

diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index f5616d21d134..d4659c1150bc 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -17,9 +17,11 @@
 #include "PassDetail.h"
 #include "Target.h"
 #include "flang/Lower/Todo.h"
+#include "flang/Optimizer/Builder/Character.h"
 #include "flang/Optimizer/CodeGen/CodeGen.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Support/FIRContext.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -42,7 +44,8 @@ struct FixupTy {
     ReturnAsStore,
     ReturnType,
     Split,
-    Trailing
+    Trailing,
+    TrailingCharProc
   };
 
   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
@@ -266,6 +269,41 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
           })
+          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+            if (factory::isCharacterProcedureTuple(tuple)) {
+              mlir::ModuleOp module = getModule();
+              if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
+                if (callOp.callee()) {
+                  llvm::StringRef charProcAttr =
+                      fir::getCharacterProcedureDummyAttrName();
+                  // The charProcAttr attribute is only used as a safety to
+                  // confirm that this is a dummy procedure and should be split.
+                  // It cannot be used to match because attributes are not
+                  // available in case of indirect calls.
+                  auto funcOp =
+                      module.lookupSymbol<mlir::FuncOp>(*callOp.callee());
+                  if (funcOp &&
+                      !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
+                          index, charProcAttr))
+                    mlir::emitError(loc, "tuple argument will be split even "
+                                         "though it does not have the `" +
+                                             charProcAttr + "` attribute");
+                }
+              }
+              mlir::Type funcPointerType = tuple.getType(0);
+              mlir::Type lenType = tuple.getType(1);
+              FirOpBuilder builder(*rewriter, getKindMapping(module));
+              auto [funcPointer, len] =
+                  factory::extractCharacterProcedureTuple(builder, loc, oper);
+              newInTys.push_back(funcPointerType);
+              newOpers.push_back(funcPointer);
+              trailingInTys.push_back(lenType);
+              trailingOpers.push_back(len);
+            } else {
+              newInTys.push_back(tuple);
+              newOpers.push_back(oper);
+            }
+          })
           .Default([&](mlir::Type ty) {
             newInTys.push_back(ty);
             newOpers.push_back(oper);
@@ -360,6 +398,14 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
             lowerComplexSignatureArg(ty, newInTys);
           })
+          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+            if (factory::isCharacterProcedureTuple(tuple)) {
+              newInTys.push_back(tuple.getType(0));
+              trailingInTys.push_back(tuple.getType(1));
+            } else {
+              newInTys.push_back(ty);
+            }
+          })
           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
     }
     // append trailing input types
@@ -394,7 +440,8 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
         return false;
       }
     for (auto ty : func.getInputs())
-      if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
+      if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) &&
+           !noCharacterConversion) ||
           (isa_complex(ty) && !noComplexConversion)) {
         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
         return false;
@@ -476,6 +523,16 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
             else
               doComplexArg(func, cmplx, newInTys, fixups);
           })
+          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+            if (factory::isCharacterProcedureTuple(tuple)) {
+              fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
+                                  newInTys.size(), trailingTys.size());
+              newInTys.push_back(tuple.getType(0));
+              trailingTys.push_back(tuple.getType(1));
+            } else {
+              newInTys.push_back(ty);
+            }
+          })
           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
     }
 
@@ -604,6 +661,23 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
           func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
           func.front().eraseArgument(fixup.index + 1);
         } break;
+        case FixupTy::Codes::TrailingCharProc: {
+          // The FIR character procedure argument tuple has been split into a
+          // pair of distinct arguments. The first part of the pair appears in
+          // the original argument position. The second part of the pair is
+          // appended after all the original arguments.
+          auto newProcPointerArg = func.front().insertArgument(
+              fixup.index, newInTys[fixup.index], loc);
+          auto newLenArg =
+              func.front().addArgument(trailingTys[fixup.second], loc);
+          auto tupleType = oldArgTys[fixup.index - offset];
+          rewriter->setInsertionPointToStart(&func.front());
+          FirOpBuilder builder(*rewriter, getKindMapping(getModule()));
+          auto tuple = factory::createCharacterProcedureTuple(
+              builder, loc, tupleType, newProcPointerArg, newLenArg);
+          func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
+          func.front().eraseArgument(fixup.index + 1);
+        } break;
         }
       }
     }

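On the callee side, the `TrailingCharProc` fixup in TargetRewrite.cpp
rewrites the entry block in place. The sketch below models that
sequence on a plain vector of block-argument names: insert the function
address in front of the old tuple argument, append the length at the
end, rebuild a tuple from the two new arguments, redirect the old
argument's uses to it, and erase the old argument. `Block`, its `uses`
map, and the naming scheme are hypothetical; only the order of
operations follows the pass.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical model of an entry block: argument names plus a map from each
// operation name to the value it uses.
struct Block {
  std::vector<std::string> args;
  std::map<std::string, std::string> uses; // op name -> used value
};

// Models the TrailingCharProc fixup for the tuple argument at `index`.
void fixupTrailingCharProc(Block &block, std::size_t index) {
  const auto pos = static_cast<std::ptrdiff_t>(index);
  const std::string oldTuple = block.args[index];
  // 1. Insert the function address at the original position; the old tuple
  //    argument shifts to index + 1.
  block.args.insert(block.args.begin() + pos, oldTuple + ".proc");
  // 2. Append the result length after all other arguments.
  block.args.push_back(oldTuple + ".len");
  // 3. Rebuild the tuple from the two new arguments (fir.undefined followed by
  //    two fir.insert_value operations in the real pass).
  const std::string rebuilt =
      "tuple(" + block.args[index] + ", " + block.args.back() + ")";
  // 4. Redirect every use of the old argument to the rebuilt tuple.
  for (auto &entry : block.uses)
    if (entry.second == oldTuple)
      entry.second = rebuilt;
  // 5. Erase the old tuple argument, now at index + 1.
  block.args.erase(block.args.begin() + pos + 1);
}
```

Applied to arguments `(arg0: char proc tuple, arg1)`, the model leaves
`arg0.proc, arg1, arg0.len`, and an operation that used `arg0` now uses
the rebuilt tuple. In the real pass, a `fir.boxchar` argument such as
`%arg1` gets its own fixup, which is how `takes_dummy_char_proc_impl`
ends up with four arguments.
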
diff  --git a/flang/test/Fir/target-rewrite-char-proc.fir b/flang/test/Fir/target-rewrite-char-proc.fir
new file mode 100644
index 000000000000..cb3e68e4aa95
--- /dev/null
+++ b/flang/test/Fir/target-rewrite-char-proc.fir
@@ -0,0 +1,69 @@
+// Test rewrite of character procedure pointer tuple argument to two different
+// arguments: one for the function address, and one for the length. The length
+// argument is appended after all the original arguments.
+// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+
+// CHECK:  func private @takes_char_proc(() -> () {fir.char_proc}, i64)
+func private @takes_char_proc(tuple<() -> (), i64> {fir.char_proc})
+
+func private @takes_char(!fir.boxchar<1>)
+func private @char_proc(!fir.ref<!fir.char<1,7>>, index) -> !fir.boxchar<1>
+
+func @_QPcst_len() {
+  %0 = fir.address_of(@char_proc) : (!fir.ref<!fir.char<1,7>>, index) -> !fir.boxchar<1>
+  %c7_i64 = arith.constant 7 : i64
+  %1 = fir.convert %0 : ((!fir.ref<!fir.char<1,7>>, index) -> !fir.boxchar<1>) -> (() -> ())
+  %2 = fir.undefined tuple<() -> (), i64>
+  %3 = fir.insert_value %2, %1, [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+  %4 = fir.insert_value %3, %c7_i64, [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+
+  // CHECK:  %[[PROC_ADDR:.*]] = fir.extract_value %{{.*}}, [0 : index] : (tuple<() -> (), i64>) -> (() -> ())
+  // CHECK:  %[[LEN:.*]] = fir.extract_value %{{.*}}, [1 : index] : (tuple<() -> (), i64>) -> i64
+  // CHECK:  fir.call @takes_char_proc(%[[PROC_ADDR]], %[[LEN]]) : (() -> (), i64) -> ()
+  fir.call @takes_char_proc(%4) : (tuple<() -> (), i64>) -> ()
+  return
+}
+
+// CHECK:  func @test_dummy_proc_that_takes_dummy_char_proc(
+// CHECK-SAME: %[[ARG0:.*]]: () -> ()) {
+func @test_dummy_proc_that_takes_dummy_char_proc(%arg0: () -> ()) {
+  %0 = fir.address_of(@char_proc) : (!fir.ref<!fir.char<1,7>>, index) -> !fir.boxchar<1>
+  %c7_i64 = arith.constant 7 : i64
+  %1 = fir.convert %0 : ((!fir.ref<!fir.char<1,7>>, index) -> !fir.boxchar<1>) -> (() -> ())
+  %2 = fir.undefined tuple<() -> (), i64>
+  %3 = fir.insert_value %2, %1, [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+  %4 = fir.insert_value %3, %c7_i64, [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+  %5 = fir.convert %arg0 : (() -> ()) -> ((tuple<() -> (), i64>) -> ())
+
+  // CHECK:  %[[ARG_CAST:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> ((() -> (), i64) -> ())
+  // CHECK:  %[[PROC_ADDR:.*]] = fir.extract_value %4, [0 : index] : (tuple<() -> (), i64>) -> (() -> ())
+  // CHECK:  %[[PROC_LEN:.*]] = fir.extract_value %4, [1 : index] : (tuple<() -> (), i64>) -> i64
+  // CHECK: fir.call %[[ARG_CAST]](%[[PROC_ADDR]], %[[PROC_LEN]]) : (() -> (), i64) -> ()
+  fir.call %5(%4) : (tuple<() -> (), i64>) -> ()
+  return
+}
+
+// CHECK:  func @takes_dummy_char_proc_impl(
+// CHECK-SAME: %[[PROC_ADDR:.*]]: () -> () {fir.char_proc},
+// CHECK-SAME: %[[C_ADDR:.*]]: !fir.ref<!fir.char<1,?>>,
+// CHECK-SAME: %[[PROC_LEN:.*]]: i64,
+// CHECK-SAME: %[[C_LEN:.*]]: i64) {
+func @takes_dummy_char_proc_impl(%arg0: tuple<() -> (), i64> {fir.char_proc}, %arg1: !fir.boxchar<1>) {
+  // CHECK:    %[[UNDEF:.*]] = fir.undefined tuple<() -> (), i64>
+  // CHECK:    %[[TUPLE0:.*]] = fir.insert_value %[[UNDEF]], %[[PROC_ADDR]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+  // CHECK:    %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[PROC_LEN]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+  %0 = fir.alloca !fir.char<1,7> {bindc_name = ".result"}
+  %1:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+  %c5 = arith.constant 5 : index
+  %2 = fir.emboxchar %1#0, %c5 : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
+  %3 = fir.extract_value %arg0, [0 : index] : (tuple<() -> (), i64>) -> (() -> ())
+  %c7_i64 = arith.constant 7 : i64
+  %4 = fir.convert %c7_i64 : (i64) -> index
+  %6 = fir.convert %3 : (() -> ()) -> ((!fir.ref<!fir.char<1,7>>, index, !fir.boxchar<1>) -> !fir.boxchar<1>)
+  %7 = fir.call %6(%0, %4, %2) : (!fir.ref<!fir.char<1,7>>, index, !fir.boxchar<1>) -> !fir.boxchar<1>
+  %8 = fir.convert %0 : (!fir.ref<!fir.char<1,7>>) -> !fir.ref<!fir.char<1,?>>
+  %9 = fir.emboxchar %8, %4 : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
+  fir.call @takes_char(%9) : (!fir.boxchar<1>) -> ()
+  return
+}
+