[flang-commits] [flang] 0501102 - [flang][LoopVersioning] support fir.declare
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Fri Aug 18 02:53:52 PDT 2023
Author: Tom Eccles
Date: 2023-08-18T09:51:22Z
New Revision: 05011024fd0a27f54e3fd9c3fb7c42f6a86ab391
URL: https://github.com/llvm/llvm-project/commit/05011024fd0a27f54e3fd9c3fb7c42f6a86ab391
DIFF: https://github.com/llvm/llvm-project/commit/05011024fd0a27f54e3fd9c3fb7c42f6a86ab391.diff
LOG: [flang][LoopVersioning] support fir.declare
When FIR comes from HLFIR, there will be a fir.declare operation between
the source and the usage of each source variable (and some temporary
allocations). This pass needs to be able to look through these declares so
that it can still transform loops when HLFIR is used; otherwise it mistakenly
assumes these values are not function arguments.
More work is needed after this patch to fully support HLFIR, because the
generated code tends to use fir.array_coor instead of fir.coordinate_of,
which is the only addressing operation this pass currently rewrites.
Differential Revision: https://reviews.llvm.org/D157964
Added:
Modified:
flang/lib/Optimizer/Transforms/LoopVersioning.cpp
flang/test/Transforms/loop-versioning.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index 511c80673f21dd..63786e377bb4a5 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -94,6 +94,16 @@ static fir::SequenceType getAsSequenceType(mlir::Value *v) {
return argTy.dyn_cast<fir::SequenceType>();
}
+/// Look through a fir.declare: when \p val is produced by a fir.declare,
+/// return that declare's memref operand; otherwise return \p val as-is.
+static mlir::Value unwrapFirDeclare(mlir::Value val) {
+  auto declare = val.getDefiningOp<fir::DeclareOp>();
+  if (!declare)
+    return val;
+  // One level is enough: declares are not nested on one another.
+  return declare.getMemref();
+}
+
void LoopVersioningPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::func::FuncOp func = getOperation();
@@ -154,9 +164,9 @@ void LoopVersioningPass::runOnOperation() {
// to it later.
if (op->getParentOfType<fir::DoLoopOp>() != loop)
return;
- const mlir::Value &operand = op->getOperand(0);
+ mlir::Value operand = op->getOperand(0);
for (auto a : argsOfInterest) {
- if (*a.arg == operand) {
+ if (*a.arg == unwrapFirDeclare(operand)) {
// Only add if it's not already in the list.
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
return it.arg == a.arg;
@@ -244,7 +254,7 @@ void LoopVersioningPass::runOnOperation() {
// arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
// where stride is the distance between elements in the dimensions
// 0, 1 and 2 or x, y and z.
- if (coop->getOperand(0) == *arg.arg &&
+ if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
coop->getOperands().size() >= 2) {
builder.setInsertionPoint(coop);
mlir::Value totalIndex;
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index 6fc8eb852c1cf3..8659273f117502 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -13,6 +13,7 @@
// end subroutine sum1d
module {
func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
+ %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
%1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
%cst = arith.constant 0.000000e+00 : f64
@@ -30,7 +31,7 @@ module {
%9 = fir.convert %8 : (i32) -> i64
%c1_i64 = arith.constant 1 : i64
%10 = arith.subi %9, %c1_i64 : i64
- %11 = fir.coordinate_of %arg0, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+ %11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
%12 = fir.load %11 : !fir.ref<f64>
%13 = arith.addf %7, %12 fastmath<contract> : f64
fir.store %13 to %1 : !fir.ref<f64>
@@ -47,6 +48,7 @@ module {
// Note this only checks the expected transformation, not the entire generated code:
// CHECK-LABEL: func.func @sum1d(
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
+// CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
// CHECK: %[[SIZE:.*]] = arith.constant 8 : index
@@ -62,7 +64,7 @@ module {
// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
// CHECK: } else {
// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
-// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
+// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
// CHECK: fir.result %{{.*}}, %{{.*}}
// CHECK: }
More information about the flang-commits
mailing list