[flang-commits] [flang] b0eef1e - [fir] Add the abstract result conversion pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Oct 11 01:10:49 PDT 2021


Author: Valentin Clement
Date: 2021-10-11T10:10:41+02:00
New Revision: b0eef1eef0500315bf74721dda3d7a8e3c6a6eac

URL: https://github.com/llvm/llvm-project/commit/b0eef1eef0500315bf74721dda3d7a8e3c6a6eac
DIFF: https://github.com/llvm/llvm-project/commit/b0eef1eef0500315bf74721dda3d7a8e3c6a6eac.diff

LOG: [fir] Add the abstract result conversion pass

Add a pass that converts abstract results to function arguments.
This pass is needed before the conversion to LLVM IR.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D111146

Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    flang/lib/Optimizer/Transforms/AbstractResult.cpp
    flang/test/Fir/abstract-results.fir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/lib/Optimizer/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index fc689b037297..5dc784ff0b50 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -26,6 +26,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
+std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass> createCharacterConversionPass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index b207ad70ba9a..309ef43d766d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file contains definitions for passes within the Optimizer/Transforms/
-//  directory.
+// directory.
 //
 //===----------------------------------------------------------------------===//
 
@@ -16,6 +16,25 @@
 
 include "mlir/Pass/PassBase.td"
 
+def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::FuncOp"> {
+  let summary = "Convert fir.array, fir.box and fir.rec function result to "
+                "function argument";
+  let description = [{
+    This pass is required before code gen to the LLVM IR dialect,
+    including the pre-cg rewrite pass.
+  }];
+  let constructor = "::fir::createAbstractResultOptPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect"
+  ];
+  let options = [
+    Option<"passResultAsBox", "abstract-result-as-box",
+           "bool", /*default=*/"false",
+           "Pass fir.array<T> result as fir.box<fir.array<T>> argument instead"
+           " of fir.ref<fir.array<T>>.">
+  ];
+}
+
 def AffineDialectPromotion : FunctionPass<"promote-to-affine"> {
   let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
   let description = [{

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 94e8d624b338..33db64c6687f 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -623,7 +623,8 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         llvm::ArrayRef<mlir::Type> results,
                         mlir::ValueRange operands) {
   result.addOperands(operands);
-  result.addAttribute(getCalleeAttrName(), callee);
+  if (callee)
+    result.addAttribute(getCalleeAttrName(), callee);
   result.addTypes(results);
 }
 

diff  --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
new file mode 100644
index 000000000000..21df4180e14c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -0,0 +1,288 @@
+//===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "flang-abstract-result-opt"
+
+namespace fir {
+namespace {
+
+struct AbstractResultOptions {
+  // Always pass result as a fir.box argument.
+  bool boxResult = false;
+  // New function block argument for the result if the current FuncOp had
+  // an abstract result.
+  mlir::Value newArg;
+};
+
+static bool mustConvertCallOrFunc(mlir::FunctionType type) {
+  if (type.getNumResults() == 0)
+    return false;
+  auto resultType = type.getResult(0);
+  return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
+}
+
+static mlir::Type getResultArgumentType(mlir::Type resultType,
+                                        const AbstractResultOptions &options) {
+  return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
+      .Case<fir::SequenceType, fir::RecordType>(
+          [&](mlir::Type type) -> mlir::Type {
+            if (options.boxResult)
+              return fir::BoxType::get(type);
+            return fir::ReferenceType::get(type);
+          })
+      .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+        return fir::ReferenceType::get(type);
+      })
+      .Default([](mlir::Type) -> mlir::Type {
+        llvm_unreachable("bad abstract result type");
+      });
+}
+
+static mlir::FunctionType
+getNewFunctionType(mlir::FunctionType funcTy,
+                   const AbstractResultOptions &options) {
+  auto resultType = funcTy.getResult(0);
+  auto argTy = getResultArgumentType(resultType, options);
+  llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
+  newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
+  return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
+                                 /*resultTypes=*/{});
+}
+
+static bool mustEmboxResult(mlir::Type resultType,
+                            const AbstractResultOptions &options) {
+  return resultType.isa<fir::SequenceType, fir::RecordType>() &&
+         options.boxResult;
+}
+
+class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  CallOpConversion(mlir::MLIRContext *context, const AbstractResultOptions &opt)
+      : OpRewritePattern(context), options{opt} {}
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = callOp.getLoc();
+    auto result = callOp->getResult(0);
+    if (!result.hasOneUse()) {
+      mlir::emitError(loc,
+                      "calls with abstract result must have exactly one user");
+      return mlir::failure();
+    }
+    auto saveResult =
+        mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
+    if (!saveResult) {
+      mlir::emitError(
+          loc, "calls with abstract result must be used in fir.save_result");
+      return mlir::failure();
+    }
+    auto argType = getResultArgumentType(result.getType(), options);
+    auto buffer = saveResult.memref();
+    mlir::Value arg = buffer;
+    if (mustEmboxResult(result.getType(), options))
+      arg = rewriter.create<fir::EmboxOp>(
+          loc, argType, buffer, saveResult.shape(), /*slice*/ mlir::Value{},
+          saveResult.typeparams());
+
+    llvm::SmallVector<mlir::Type> newResultTypes;
+    if (callOp.callee()) {
+      llvm::SmallVector<mlir::Value> newOperands = {arg};
+      newOperands.append(callOp.getOperands().begin(),
+                         callOp.getOperands().end());
+      rewriter.create<fir::CallOp>(loc, callOp.callee().getValue(),
+                                   newResultTypes, newOperands);
+    } else {
+      // Indirect calls.
+      llvm::SmallVector<mlir::Type> newInputTypes = {argType};
+      for (auto operand : callOp.getOperands().drop_front())
+        newInputTypes.push_back(operand.getType());
+      auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
+                                           newResultTypes);
+
+      llvm::SmallVector<mlir::Value> newOperands;
+      newOperands.push_back(
+          rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
+      newOperands.push_back(arg);
+      newOperands.append(callOp.getOperands().begin() + 1,
+                         callOp.getOperands().end());
+      rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
+                                   newOperands);
+    }
+    callOp->dropAllReferences();
+    rewriter.eraseOp(callOp);
+    return mlir::success();
+  }
+
+private:
+  const AbstractResultOptions &options;
+};
+
+class SaveResultOpConversion
+    : public mlir::OpRewritePattern<fir::SaveResultOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  SaveResultOpConversion(mlir::MLIRContext *context)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(fir::SaveResultOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
+class ReturnOpConversion : public mlir::OpRewritePattern<mlir::ReturnOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  ReturnOpConversion(mlir::MLIRContext *context,
+                     const AbstractResultOptions &opt)
+      : OpRewritePattern(context), options{opt} {}
+  mlir::LogicalResult
+  matchAndRewrite(mlir::ReturnOp ret,
+                  mlir::PatternRewriter &rewriter) const override {
+    rewriter.setInsertionPoint(ret);
+    auto returnedValue = ret.getOperand(0);
+    bool replacedStorage = false;
+    if (auto *op = returnedValue.getDefiningOp())
+      if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
+        auto resultStorage = load.memref();
+        load.memref().replaceAllUsesWith(options.newArg);
+        replacedStorage = true;
+        if (auto *alloc = resultStorage.getDefiningOp())
+          if (alloc->use_empty())
+            rewriter.eraseOp(alloc);
+      }
+    // The result storage may have been optimized out by a memory-to-register
+    // pass; this is possible for fir.box results, or for fir.record results
+    // with no length parameters. In that case, simply store the returned
+    // value into the result storage at the return point.
+    if (!replacedStorage)
+      rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue,
+                                    options.newArg);
+    rewriter.replaceOpWithNewOp<mlir::ReturnOp>(ret);
+    return mlir::success();
+  }
+
+private:
+  const AbstractResultOptions &options;
+};
+
+class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  AddrOfOpConversion(mlir::MLIRContext *context,
+                     const AbstractResultOptions &opt)
+      : OpRewritePattern(context), options{opt} {}
+  mlir::LogicalResult
+  matchAndRewrite(fir::AddrOfOp addrOf,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
+    auto newFuncTy = getNewFunctionType(oldFuncTy, options);
+    auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
+                                                    addrOf.symbol());
+    // Rather than converting all ops a function pointer might transit through
+    // (e.g., calls, stores, loads, converts...), cast the new type to the
+    // abstract type. A conversion is added when rewriting indirect calls with
+    // abstract results.
+    rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
+    return mlir::success();
+  }
+
+private:
+  const AbstractResultOptions &options;
+};
+
+class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
+public:
+  void runOnOperation() override {
+    auto *context = &getContext();
+    auto func = getOperation();
+    auto loc = func.getLoc();
+    mlir::OwningRewritePatternList patterns(context);
+    mlir::ConversionTarget target = *context;
+    AbstractResultOptions options{passResultAsBox.getValue(),
+                                  /*newArg=*/{}};
+
+    // Convert function type itself if it has an abstract result
+    auto funcTy = func.getType().cast<mlir::FunctionType>();
+    if (mustConvertCallOrFunc(funcTy)) {
+      func.setType(getNewFunctionType(funcTy, options));
+      unsigned zero = 0;
+      if (!func.empty()) {
+        // Insert new argument
+        mlir::OpBuilder rewriter(context);
+        auto resultType = funcTy.getResult(0);
+        auto argTy = getResultArgumentType(resultType, options);
+        options.newArg = func.front().insertArgument(zero, argTy);
+        if (mustEmboxResult(resultType, options)) {
+          auto bufferType = fir::ReferenceType::get(resultType);
+          rewriter.setInsertionPointToStart(&func.front());
+          options.newArg =
+              rewriter.create<fir::BoxAddrOp>(loc, bufferType, options.newArg);
+        }
+        patterns.insert<ReturnOpConversion>(context, options);
+        target.addDynamicallyLegalOp<mlir::ReturnOp>(
+            [](mlir::ReturnOp ret) { return ret.operands().empty(); });
+      }
+    }
+
+    if (func.empty())
+      return;
+
+    // Convert the calls and, if needed, the ReturnOp in the function body.
+    target.addLegalDialect<fir::FIROpsDialect, mlir::StandardOpsDialect>();
+    target.addIllegalOp<fir::SaveResultOp>();
+    target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
+      return !mustConvertCallOrFunc(call.getFunctionType());
+    });
+    target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
+      if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
+        return !mustConvertCallOrFunc(funTy);
+      return true;
+    });
+    target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+      if (dispatch->getNumResults() != 1)
+        return true;
+      auto resultType = dispatch->getResult(0).getType();
+      if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
+        mlir::emitError(dispatch.getLoc(),
+                        "TODO: dispatchOp with abstract results");
+        return false;
+      }
+      return true;
+    });
+
+    patterns.insert<CallOpConversion>(context, options);
+    patterns.insert<SaveResultOpConversion>(context);
+    patterns.insert<AddrOfOpConversion>(context, options);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(), "error in converting abstract results\n");
+      signalPassFailure();
+    }
+  }
+};
+} // end anonymous namespace
+} // namespace fir
+
+std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
+  return std::make_unique<AbstractResultOpt>();
+}

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 6465ba8c5599..99b022edb948 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_flang_library(FIRTransforms
+  AbstractResult.cpp
   AffinePromotion.cpp
   AffineDemotion.cpp
   CharacterConversion.cpp

diff  --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
new file mode 100644
index 000000000000..e7b24f268acd
--- /dev/null
+++ b/flang/test/Fir/abstract-results.fir
@@ -0,0 +1,255 @@
+// Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
+// functions that take an additional argument for the result.
+
+// RUN: fir-opt %s --abstract-result-opt | FileCheck %s
+// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX
+
+// ----------------------- Test declaration rewrite ----------------------------
+
+// CHECK-LABEL:  func private @arrayfunc(!fir.ref<!fir.array<?xf32>>, i32)
+// CHECK-BOX-LABEL:  func private @arrayfunc(!fir.box<!fir.array<?xf32>>, i32)
+func private @arrayfunc(i32) -> !fir.array<?xf32>
+
+// CHECK-LABEL:  func private @derivedfunc(!fir.ref<!fir.type<t{x:f32}>>, f32)
+// CHECK-BOX-LABEL:  func private @derivedfunc(!fir.box<!fir.type<t{x:f32}>>, f32)
+func private @derivedfunc(f32) -> !fir.type<t{x:f32}>
+
+// CHECK-LABEL:  func private @boxfunc(!fir.ref<!fir.box<!fir.heap<f64>>>, i64)
+// CHECK-BOX-LABEL:  func private @boxfunc(!fir.ref<!fir.box<!fir.heap<f64>>>, i64)
+func private @boxfunc(i64) -> !fir.box<!fir.heap<f64>>
+
+
+// ------------------------ Test callee rewrite --------------------------------
+
+// CHECK-LABEL:  func private @arrayfunc_callee(
+// CHECK-SAME: %[[buffer:.*]]: !fir.ref<!fir.array<?xf32>>, %[[n:.*]]: index) {
+// CHECK-BOX-LABEL:  func private @arrayfunc_callee(
+// CHECK-BOX-SAME: %[[box:.*]]: !fir.box<!fir.array<?xf32>>, %[[n:.*]]: index) {
+func private @arrayfunc_callee(%n : index) -> !fir.array<?xf32> {
+  %buffer = fir.alloca !fir.array<?xf32>, %n
+  // Do something with result (res(4) = 42.)
+  %c4 = constant 4 : i64
+  %coor = fir.coordinate_of %buffer, %c4 : (!fir.ref<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+  %cst = constant 4.200000e+01 : f32
+  fir.store %cst to %coor : !fir.ref<f32>
+  %res = fir.load %buffer : !fir.ref<!fir.array<?xf32>>
+  return %res : !fir.array<?xf32>
+
+  // CHECK-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+  // CHECK-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref<f32>
+  // CHECK: return
+
+  // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+  // CHECK-BOX-DAG: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref<!fir.array<?xf32>>, i64) -> !fir.ref<f32>
+  // CHECK-BOX-DAG: fir.store %{{.*}} to %[[coor]] : !fir.ref<f32>
+  // CHECK-BOX: return
+}
+
+
+// CHECK-LABEL: func @derivedfunc_callee(
+// CHECK-SAME: %[[buffer:.*]]: !fir.ref<!fir.type<t{x:f32}>>, %[[v:.*]]: f32) {
+// CHECK-BOX-LABEL: func @derivedfunc_callee(
+// CHECK-BOX-SAME: %[[box:.*]]: !fir.box<!fir.type<t{x:f32}>>, %[[v:.*]]: f32) {
+func @derivedfunc_callee(%v: f32) -> !fir.type<t{x:f32}> {
+  %buffer = fir.alloca !fir.type<t{x:f32}>
+  %0 = fir.field_index x, !fir.type<t{x:f32}>
+  %1 = fir.coordinate_of %buffer, %0 : (!fir.ref<!fir.type<t{x:f32}>>, !fir.field) -> !fir.ref<f32>
+  fir.store %v to %1 : !fir.ref<f32>
+  %res = fir.load %buffer : !fir.ref<!fir.type<t{x:f32}>>
+  return %res : !fir.type<t{x:f32}>
+
+  // CHECK: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref<!fir.type<t{x:f32}>>, !fir.field) -> !fir.ref<f32>
+  // CHECK: fir.store %[[v]] to %[[coor]] : !fir.ref<f32>
+  // CHECK: return
+
+  // CHECK-BOX: %[[buffer:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.type<t{x:f32}>>) -> !fir.ref<!fir.type<t{x:f32}>>
+  // CHECK-BOX: %[[coor:.*]] = fir.coordinate_of %[[buffer]], %{{.*}} : (!fir.ref<!fir.type<t{x:f32}>>, !fir.field) -> !fir.ref<f32>
+  // CHECK-BOX: fir.store %[[v]] to %[[coor]] : !fir.ref<f32>
+  // CHECK-BOX: return
+}
+
+// CHECK-LABEL: func @boxfunc_callee(
+// CHECK-SAME: %[[buffer:.*]]: !fir.ref<!fir.box<!fir.heap<f64>>>) {
+// CHECK-BOX-LABEL: func @boxfunc_callee(
+// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref<!fir.box<!fir.heap<f64>>>) {
+func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
+  %alloc = fir.allocmem f64
+  %res = fir.embox %alloc : (!fir.heap<f64>) -> !fir.box<!fir.heap<f64>>
+  return %res : !fir.box<!fir.heap<f64>>
+  // CHECK: %[[box:.*]] = fir.embox %{{.*}} : (!fir.heap<f64>) -> !fir.box<!fir.heap<f64>>
+  // CHECK: fir.store %[[box]] to %[[buffer]] : !fir.ref<!fir.box<!fir.heap<f64>>>
+  // CHECK: return
+
+  // CHECK-BOX: %[[box:.*]] = fir.embox %{{.*}} : (!fir.heap<f64>) -> !fir.box<!fir.heap<f64>>
+  // CHECK-BOX: fir.store %[[box]] to %[[buffer]] : !fir.ref<!fir.box<!fir.heap<f64>>>
+  // CHECK-BOX: return
+}
+
+// ------------------------ Test caller rewrite --------------------------------
+
+// CHECK-LABEL: func @call_arrayfunc() {
+// CHECK-BOX-LABEL: func @call_arrayfunc() {
+func @call_arrayfunc() {
+  %c100 = constant 100 : index
+  %buffer = fir.alloca !fir.array<?xf32>, %c100
+  %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+  %res = fir.call @arrayfunc_callee(%c100) : (index) -> !fir.array<?xf32>
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+  return
+
+  // CHECK: %[[c100:.*]] = constant 100 : index
+  // CHECK: %[[buffer:.*]] = fir.alloca !fir.array<?xf32>, %[[c100]]
+  // CHECK: fir.call @arrayfunc_callee(%[[buffer]], %[[c100]]) : (!fir.ref<!fir.array<?xf32>>, index) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[c100:.*]] = constant 100 : index
+  // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array<?xf32>, %[[c100]]
+  // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1>
+  // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+  // CHECK-BOX: fir.call @arrayfunc_callee(%[[box]], %[[c100]]) : (!fir.box<!fir.array<?xf32>>, index) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}
+
+// CHECK-LABEL: func @call_derivedfunc() {
+// CHECK-BOX-LABEL: func @call_derivedfunc() {
+func @call_derivedfunc() {
+  %buffer = fir.alloca !fir.type<t{x:f32}>
+  %cst = constant 4.200000e+01 : f32
+  %res = fir.call @derivedfunc_callee(%cst) : (f32) -> !fir.type<t{x:f32}>
+  fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>
+  return
+  // CHECK: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // CHECK: %[[cst:.*]] = constant {{.*}} : f32
+  // CHECK: fir.call @derivedfunc_callee(%[[buffer]], %[[cst]]) : (!fir.ref<!fir.type<t{x:f32}>>, f32) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // CHECK-BOX: %[[cst:.*]] = constant {{.*}} : f32
+  // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>>
+  // CHECK-BOX: fir.call @derivedfunc_callee(%[[box]], %[[cst]]) : (!fir.box<!fir.type<t{x:f32}>>, f32) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}
+
+func private @derived_lparams_func() -> !fir.type<t2(l1:i32,l2:i32){x:f32}>
+
+// CHECK-LABEL: func @call_derived_lparams_func(
+// CHECK-SAME: %[[buffer:.*]]: !fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>
+// CHECK-BOX-LABEL: func @call_derived_lparams_func(
+// CHECK-BOX-SAME: %[[buffer:.*]]: !fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>
+func @call_derived_lparams_func(%buffer: !fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>) {
+  %l1 = constant 3 : i32
+  %l2 = constant 5 : i32
+  %res = fir.call @derived_lparams_func() : () -> !fir.type<t2(l1:i32,l2:i32){x:f32}>
+  fir.save_result %res to %buffer typeparams %l1, %l2 : !fir.type<t2(l1:i32,l2:i32){x:f32}>, !fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>, i32, i32
+  return
+
+  // CHECK: %[[l1:.*]] = constant 3 : i32
+  // CHECK: %[[l2:.*]] = constant 5 : i32
+  // CHECK: fir.call @derived_lparams_func(%[[buffer]]) : (!fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[l1:.*]] = constant 3 : i32
+  // CHECK-BOX: %[[l2:.*]] = constant 5 : i32
+  // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]] typeparams %[[l1]], %[[l2]] : (!fir.ref<!fir.type<t2(l1:i32,l2:i32){x:f32}>>, i32, i32) -> !fir.box<!fir.type<t2(l1:i32,l2:i32){x:f32}>>
+  // CHECK-BOX: fir.call @derived_lparams_func(%[[box]]) : (!fir.box<!fir.type<t2(l1:i32,l2:i32){x:f32}>>) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}
+
+// CHECK-LABEL: func @call_boxfunc() {
+// CHECK-BOX-LABEL: func @call_boxfunc() {
+func @call_boxfunc() {
+  %buffer = fir.alloca !fir.box<!fir.heap<f64>>
+  %res = fir.call @boxfunc_callee() : () -> !fir.box<!fir.heap<f64>>
+  fir.save_result %res to %buffer: !fir.box<!fir.heap<f64>>, !fir.ref<!fir.box<!fir.heap<f64>>>
+  return
+
+  // CHECK: %[[buffer:.*]] = fir.alloca !fir.box<!fir.heap<f64>>
+  // CHECK: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref<!fir.box<!fir.heap<f64>>>) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.box<!fir.heap<f64>>
+  // CHECK-BOX: fir.call @boxfunc_callee(%[[buffer]]) : (!fir.ref<!fir.box<!fir.heap<f64>>>) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}
+
+func private @chararrayfunc(index, index) -> !fir.array<?x!fir.char<1,?>>
+
+// CHECK-LABEL: func @call_chararrayfunc() {
+// CHECK-BOX-LABEL: func @call_chararrayfunc() {
+func @call_chararrayfunc() {
+  %c100 = constant 100 : index
+  %c50 = constant 50 : index
+  %buffer = fir.alloca !fir.array<?x!fir.char<1,?>>(%c100 : index), %c50
+  %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+  %res = fir.call @chararrayfunc(%c100, %c50) : (index, index) -> !fir.array<?x!fir.char<1,?>>
+  fir.save_result %res to %buffer(%shape) typeparams %c50 : !fir.array<?x!fir.char<1,?>>, !fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index
+  return
+
+  // CHECK: %[[c100:.*]] = constant 100 : index
+  // CHECK: %[[c50:.*]] = constant 50 : index
+  // CHECK: %[[buffer:.*]] = fir.alloca !fir.array<?x!fir.char<1,?>>(%[[c100]] : index), %[[c50]]
+  // CHECK: fir.call @chararrayfunc(%[[buffer]], %[[c100]], %[[c50]]) : (!fir.ref<!fir.array<?x!fir.char<1,?>>>, index, index) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[c100:.*]] = constant 100 : index
+  // CHECK-BOX: %[[c50:.*]] = constant 50 : index
+  // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array<?x!fir.char<1,?>>(%[[c100]] : index), %[[c50]]
+  // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1>
+  // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) typeparams %[[c50]] : (!fir.ref<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index) -> !fir.box<!fir.array<?x!fir.char<1,?>>>
+  // CHECK-BOX: fir.call @chararrayfunc(%[[box]], %[[c100]], %[[c50]]) : (!fir.box<!fir.array<?x!fir.char<1,?>>>, index, index) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}
+
+// ------------------------ Test fir.address_of rewrite ------------------------
+
+func private @takesfuncarray((i32) -> !fir.array<?xf32>)
+
+// CHECK-LABEL: func @test_address_of() {
+// CHECK-BOX-LABEL: func @test_address_of() {
+func @test_address_of() {
+  %0 = fir.address_of(@arrayfunc) : (i32) -> !fir.array<?xf32>
+  fir.call @takesfuncarray(%0) : ((i32) -> !fir.array<?xf32>) -> ()
+  return
+
+  // CHECK: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.ref<!fir.array<?xf32>>, i32) -> ()
+  // CHECK: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.ref<!fir.array<?xf32>>, i32) -> ()) -> ((i32) -> !fir.array<?xf32>)
+  // CHECK: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array<?xf32>) -> ()
+
+  // CHECK-BOX: %[[addrOf:.*]] = fir.address_of(@arrayfunc) : (!fir.box<!fir.array<?xf32>>, i32) -> ()
+  // CHECK-BOX: %[[conv:.*]] = fir.convert %[[addrOf]] : ((!fir.box<!fir.array<?xf32>>, i32) -> ()) -> ((i32) -> !fir.array<?xf32>)
+  // CHECK-BOX: fir.call @takesfuncarray(%[[conv]]) : ((i32) -> !fir.array<?xf32>) -> ()
+
+}
+
+// ----------------------- Test indirect calls rewrite ------------------------
+
+// CHECK-LABEL: func @test_indirect_calls(
+// CHECK-SAME: %[[arg0:.*]]: () -> ()) {
+// CHECK-BOX-LABEL: func @test_indirect_calls(
+// CHECK-BOX-SAME: %[[arg0:.*]]: () -> ()) {
+func @test_indirect_calls(%arg0: () -> ()) {
+  %c100 = constant 100 : index
+  %buffer = fir.alloca !fir.array<?xf32>, %c100
+  %shape = fir.shape %c100 : (index) -> !fir.shape<1>
+  %0 = fir.convert %arg0 : (() -> ()) -> ((index) -> !fir.array<?xf32>)
+  %res = fir.call %0(%c100) : (index) -> !fir.array<?xf32>
+  fir.save_result %res to %buffer(%shape) : !fir.array<?xf32>, !fir.ref<!fir.array<?xf32>>, !fir.shape<1>
+  return
+
+  // CHECK: %[[c100:.*]] = constant 100 : index
+  // CHECK: %[[buffer:.*]] = fir.alloca !fir.array<?xf32>, %[[c100]]
+  // CHECK: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1>
+  // CHECK: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array<?xf32>)
+  // CHECK: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array<?xf32>) -> ((!fir.ref<!fir.array<?xf32>>, index) -> ())
+  // CHECK: fir.call %[[conv]](%[[buffer]], %c100) : (!fir.ref<!fir.array<?xf32>>, index) -> ()
+  // CHECK-NOT: fir.save_result
+
+  // CHECK-BOX: %[[c100:.*]] = constant 100 : index
+  // CHECK-BOX: %[[buffer:.*]] = fir.alloca !fir.array<?xf32>, %[[c100]]
+  // CHECK-BOX: %[[shape:.*]] = fir.shape %[[c100]] : (index) -> !fir.shape<1>
+  // CHECK-BOX: %[[original_conv:.*]] = fir.convert %[[arg0]] : (() -> ()) -> ((index) -> !fir.array<?xf32>)
+  // CHECK-BOX: %[[box:.*]] = fir.embox %[[buffer]](%[[shape]]) : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
+  // CHECK-BOX: %[[conv:.*]] = fir.convert %[[original_conv]] : ((index) -> !fir.array<?xf32>) -> ((!fir.box<!fir.array<?xf32>>, index) -> ())
+  // CHECK-BOX: fir.call %[[conv]](%[[box]], %c100) : (!fir.box<!fir.array<?xf32>>, index) -> ()
+  // CHECK-BOX-NOT: fir.save_result
+}


        


More information about the flang-commits mailing list