[flang-commits] [flang] 460f828 - [flang] Lower statement function

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Mar 18 06:47:24 PDT 2022


Author: Valentin Clement
Date: 2022-03-18T14:47:16+01:00
New Revision: 460f828f09d2c826882e2136654e8bb3057cbc4b

URL: https://github.com/llvm/llvm-project/commit/460f828f09d2c826882e2136654e8bb3057cbc4b
DIFF: https://github.com/llvm/llvm-project/commit/460f828f09d2c826882e2136654e8bb3057cbc4b.diff

LOG: [flang] Lower statement function

This patch adds lowering to support statement functions.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D121990

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    flang/test/Lower/statement-function.f90

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRType.h
    flang/lib/Lower/ConvertExpr.cpp
    flang/lib/Optimizer/Dialect/FIRType.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 9758ba1686b9c..a8bb67980a0be 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -191,6 +191,10 @@ inline bool singleIndirectionLevel(mlir::Type ty) {
 }
 #endif
 
+/// Return true iff `ty` is the type of a POINTER entity or value.
+/// `isa_ref_type()` can be used to distinguish.
+bool isPointerType(mlir::Type ty);
+
 /// Return true iff `ty` is the type of an ALLOCATABLE entity or value.
 bool isAllocatableType(mlir::Type ty);
 

diff  --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 5f3877ea9cd44..6f2d528a6cc67 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -1826,7 +1826,7 @@ class ScalarExprLowering {
   template <typename A>
   ExtValue genFunctionRef(const Fortran::evaluate::FunctionRef<A> &funcRef) {
     if (!funcRef.GetType().has_value())
-      fir::emitFatalError(getLoc(), "internal: a function must have a type");
+      fir::emitFatalError(getLoc(), "a function must have a type");
     mlir::Type resTy = genType(*funcRef.GetType());
     return genProcedureRef(funcRef, {resTy});
   }
@@ -1836,12 +1836,8 @@ class ScalarExprLowering {
   template <typename A>
   ExtValue gen(const Fortran::evaluate::FunctionRef<A> &funcRef) {
     ExtValue retVal = genFunctionRef(funcRef);
-    mlir::Value retValBase = fir::getBase(retVal);
-    if (fir::conformsWithPassByRef(retValBase.getType()))
-      return retVal;
-    auto mem = builder.create<fir::AllocaOp>(getLoc(), retValBase.getType());
-    builder.create<fir::StoreOp>(getLoc(), retValBase, mem);
-    return fir::substBase(retVal, mem.getResult());
+    mlir::Type resultType = converter.genType(toEvExpr(funcRef));
+    return placeScalarValueInMemory(builder, getLoc(), retVal, resultType);
   }
 
   /// helper to detect statement functions
@@ -2457,6 +2453,25 @@ class ScalarExprLowering {
           caller.placeInput(arg, boxStorage);
           continue;
         }
+        if (fir::isPointerType(argTy) &&
+            !Fortran::evaluate::IsObjectPointer(
+                *expr, converter.getFoldingContext())) {
+          // Passing a non POINTER actual argument to a POINTER dummy argument.
+          // Create a pointer of the dummy argument type and assign the actual
+          // argument to it.
+          mlir::Value irBox =
+              builder.createTemporary(loc, fir::unwrapRefType(argTy));
+          // Non deferred parameters will be evaluated on the callee side.
+          fir::MutableBoxValue pointer(irBox,
+                                       /*nonDeferredParams=*/mlir::ValueRange{},
+                                       /*mutableProperties=*/{});
+          Fortran::lower::associateMutableBox(converter, loc, pointer, *expr,
+                                              /*lbounds*/ mlir::ValueRange{},
+                                              stmtCtx);
+          caller.placeInput(arg, irBox);
+          continue;
+        }
+        // Passing a POINTER to a POINTER, or an ALLOCATABLE to an ALLOCATABLE.
         fir::MutableBoxValue mutableBox = genMutableBoxValue(*expr);
         mlir::Value irBox =
             fir::factory::getMutableIRBox(builder, loc, mutableBox);

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 2e35cdcb167bc..daa34038df53a 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -246,6 +246,14 @@ bool hasDynamicSize(mlir::Type t) {
   return false;
 }
 
+bool isPointerType(mlir::Type ty) {
+  if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
+    ty = refTy;
+  if (auto boxTy = ty.dyn_cast<fir::BoxType>())
+    return boxTy.getEleTy().isa<fir::PointerType>();
+  return false;
+}
+
 bool isAllocatableType(mlir::Type ty) {
   if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
     ty = refTy;

diff  --git a/flang/test/Lower/statement-function.f90 b/flang/test/Lower/statement-function.f90
new file mode 100644
index 0000000000000..4f47668bab84d
--- /dev/null
+++ b/flang/test/Lower/statement-function.f90
@@ -0,0 +1,147 @@
+! RUN: bbc -emit-fir -outline-intrinsics %s -o - | FileCheck %s
+
+! Test statement function lowering
+
+! Simple case
+  ! CHECK-LABEL: func @_QPtest_stmt_0(
+  ! CHECK-SAME: %{{.*}}: !fir.ref<f32>{{.*}}) -> f32
+real function test_stmt_0(x)
+real :: x, func, arg
+func(arg) = arg + 0.123456
+
+! CHECK-DAG: %[[x:.*]] = fir.load %arg0
+! CHECK-DAG: %[[cst:.*]] = arith.constant 1.234560e-01
+! CHECK: %[[eval:.*]] = arith.addf %[[x]], %[[cst]]
+! CHECK: fir.store %[[eval]] to %[[resmem:.*]] : !fir.ref<f32>
+test_stmt_0 = func(x)
+
+! CHECK: %[[res:.*]] = fir.load %[[resmem]]
+! CHECK: return %[[res]]
+end function
+
+! Check this is not lowered as a simple macro: i.e. the argument is only
+! evaluated once even if it appears in several places inside the
+! statement function expression.
+! CHECK-LABEL: func @_QPtest_stmt_only_eval_arg_once() -> f32
+real(4) function test_stmt_only_eval_arg_once()
+real(4) :: only_once, x1
+func(x1) = x1 + x1
+! CHECK: %[[x2:.*]] = fir.alloca f32 {adapt.valuebyref}
+! CHECK: %[[x1:.*]] = fir.call @_QPonly_once()
+! Note: using -emit-fir, so the faked pass-by-reference is exposed
+! CHECK: fir.store %[[x1]] to %[[x2]]
+! CHECK: addf %{{.*}}, %{{.*}}
+test_stmt_only_eval_arg_once = func(only_once())
+end function
+
+! Test nested statement function (note that they cannot be recursively
+! nested as per F2018 C1577).
+real function test_stmt_1(x, a)
+real :: y, a, b, foo
+real :: func1, arg1, func2, arg2
+real :: res1, res2
+func1(arg1) = a + foo(arg1)
+func2(arg2) = func1(arg2) + b
+! CHECK-DAG: %[[bmem:.*]] = fir.alloca f32 {{{.*}}uniq_name = "{{.*}}Eb"}
+! CHECK-DAG: %[[res1:.*]] = fir.alloca f32 {{{.*}}uniq_name = "{{.*}}Eres1"}
+! CHECK-DAG: %[[res2:.*]] = fir.alloca f32 {{{.*}}uniq_name = "{{.*}}Eres2"}
+
+b = 5
+
+! CHECK-DAG: %[[cst_8:.*]] = arith.constant 8.000000e+00
+! CHECK-DAG: fir.store %[[cst_8]] to %[[tmp1:.*]] : !fir.ref<f32>
+! CHECK-DAG: %[[foocall1:.*]] = fir.call @_QPfoo(%[[tmp1]])
+! CHECK-DAG: %[[aload1:.*]] = fir.load %arg1
+! CHECK: %[[add1:.*]] = arith.addf %[[aload1]], %[[foocall1]]
+! CHECK: fir.store %[[add1]] to %[[res1]]
+res1 =  func1(8.)
+
+! CHECK-DAG: %[[a2:.*]] = fir.load %arg1
+! CHECK-DAG: %[[foocall2:.*]] = fir.call @_QPfoo(%arg0)
+! CHECK-DAG: %[[add2:.*]] = arith.addf %[[a2]], %[[foocall2]]
+! CHECK-DAG: %[[b:.*]] = fir.load %[[bmem]]
+! CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[b]]
+! CHECK: fir.store %[[add3]] to %[[res2]]
+res2 = func2(x)
+
+! CHECK-DAG: %[[res12:.*]] = fir.load %[[res1]]
+! CHECK-DAG: %[[res22:.*]] = fir.load %[[res2]]
+! CHECK: = arith.addf %[[res12]], %[[res22]] : f32
+test_stmt_1 = res1 + res2
+! CHECK: return %{{.*}} : f32
+end function
+
+
+! Test statement functions with no argument.
+! Test that they are not pre-evaluated.
+! CHECK-LABEL: func @_QPtest_stmt_no_args
+real function test_stmt_no_args(x, y)
+func() = x + y
+! CHECK: addf
+a = func()
+! CHECK: fir.call @_QPfoo_may_modify_xy
+call foo_may_modify_xy(x, y)
+! CHECK: addf
+! CHECK: addf
+test_stmt_no_args = func() + a
+end function
+
+! Test statement function with character arguments
+! CHECK-LABEL: @_QPtest_stmt_character
+integer function test_stmt_character(c, j)
+ integer :: i, j, func, argj
+ character(10) :: c, argc
+ ! CHECK-DAG: %[[unboxed:.*]]:2 = fir.unboxchar %arg0 :
+ ! CHECK-DAG: %[[c10:.*]] = arith.constant 10 :
+ ! CHECK: %[[c10_cast:.*]] = fir.convert %[[c10]] : (i32) -> index
+ ! CHECK: %[[c:.*]] = fir.emboxchar %[[unboxed]]#0, %[[c10_cast]]
+
+ func(argc, argj) = len_trim(argc, 4) + argj
+ ! CHECK: addi %{{.*}}, %{{.*}} : i
+ test_stmt_character = func(c, j)
+end function
+
+! Test statement function with a character actual argument whose
+! length may be different than the dummy length (the dummy length
+! must be used inside the statement function).
+! CHECK-LABEL: @_QPtest_stmt_character_with_different_length(
+! CHECK-SAME: %[[arg0:.*]]: !fir.boxchar<1>
+integer function test_stmt_character_with_different_length(c)
+ integer :: func, ifoo
+ character(10) :: argc
+ character(*) :: c
+ ! CHECK-DAG: %[[unboxed:.*]]:2 = fir.unboxchar %[[arg0]] :
+ ! CHECK-DAG: %[[c10:.*]] = arith.constant 10 :
+ ! CHECK: %[[c10_cast:.*]] = fir.convert %[[c10]] : (i32) -> index
+ ! CHECK: %[[argc:.*]] = fir.emboxchar %[[unboxed]]#0, %[[c10_cast]]
+ ! CHECK: fir.call @_QPifoo(%[[argc]]) : (!fir.boxchar<1>) -> i32
+ func(argc) = ifoo(argc)
+ test_stmt_character = func(c)
+end function
+
+! CHECK-LABEL: @_QPtest_stmt_character_with_different_length_2(
+! CHECK-SAME: %[[arg0:.*]]: !fir.boxchar<1>{{.*}}, %[[arg1:.*]]: !fir.ref<i32>
+integer function test_stmt_character_with_different_length_2(c, n)
+ integer :: func, ifoo
+ character(n) :: argc
+ character(*) :: c
+ ! CHECK: %[[unboxed:.*]]:2 = fir.unboxchar %[[arg0]] :
+ ! CHECK: fir.load %[[arg1]] : !fir.ref<i32>
+ ! CHECK: %[[n:.*]] = fir.load %[[arg1]] : !fir.ref<i32>
+ ! CHECK: %[[n_is_positive:.*]] = arith.cmpi sgt, %[[n]], %c0{{.*}} : i32
+ ! CHECK: %[[len:.*]] = arith.select %[[n_is_positive]], %[[n]], %c0{{.*}} : i32
+ ! CHECK: %[[lenCast:.*]] = fir.convert %[[len]] : (i32) -> index
+ ! CHECK: %[[argc:.*]] = fir.emboxchar %[[unboxed]]#0, %[[lenCast]] : (!fir.ref<!fir.char<1,?>>, index) -> !fir.boxchar<1>
+ ! CHECK: fir.call @_QPifoo(%[[argc]]) : (!fir.boxchar<1>) -> i32
+ func(argc) = ifoo(argc)
+ test_stmt_character = func(c)
+end function
+
+! issue #247
+! CHECK-LABEL: @_QPbug247
+subroutine bug247(r)
+I(R) = R
+! CHECK: fir.call {{.*}}OutputInteger
+PRINT *, I(2.5)
+! CHECK: fir.call {{.*}}EndIo
+END subroutine bug247


        


More information about the flang-commits mailing list