[flang-commits] [flang] [flang][openacc] Do not lose attributes on folding (PR #80516)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Feb 2 16:48:52 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/80516

From b62ee801d5e6174648dd06e0f07014197c34358d Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 2 Feb 2024 16:15:02 -0800
Subject: [PATCH 1/2] [flang][openacc] Do not lose attributes on folding

---
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 17 +++++++-
 .../Fir/OpenACC/propagate-attr-folding.fir    | 39 +++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Fir/OpenACC/propagate-attr-folding.fir

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index f826f2566b897..96a676149203d 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -20,6 +20,7 @@
 #include "flang/Optimizer/Support/Utils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -36,6 +37,18 @@ namespace {
 #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc"
 } // namespace
 
+static void propagateAttributes(mlir::Operation *fromOp,
+                                mlir::Operation *toOp) {
+  if (!fromOp || !toOp)
+    return;
+
+  for (mlir::NamedAttribute attr : fromOp->getAttrs()) {
+    if (attr.getName().str().rfind(
+            mlir::acc::OpenACCDialect::getDialectNamespace(), 0) == 0)
+      toOp->setAttr(attr.getName().str(), attr.getValue());
+  }
+}
+
 /// Return true if a sequence type is of some incomplete size or a record type
 /// is malformed or contains an incomplete sequence type. An incomplete sequence
 /// type is one with more unknown extents in the type than have been provided
@@ -626,8 +639,10 @@ mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) {
   if (auto *v = getVal().getDefiningOp()) {
     if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) {
       // Fold only if not sliced
-      if (!box.getSlice() && box.getMemref().getType() == getType())
+      if (!box.getSlice() && box.getMemref().getType() == getType()) {
+        propagateAttributes(getOperation(), box.getMemref().getDefiningOp());
         return box.getMemref();
+      }
     }
     if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v))
       if (box.getMemref().getType() == getType())
diff --git a/flang/test/Fir/OpenACC/propagate-attr-folding.fir b/flang/test/Fir/OpenACC/propagate-attr-folding.fir
new file mode 100644
index 0000000000000..f84269f5fef26
--- /dev/null
+++ b/flang/test/Fir/OpenACC/propagate-attr-folding.fir
@@ -0,0 +1,39 @@
+// RUN: fir-opt %s --opt-bufferization | FileCheck %s 
+
+// Check that OpenACC attributes are propagated to the defining operations when
+// fir.box_addr is folded in bufferization optimization.
+
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n1"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n2"}) {
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = fir.declare %arg1 {uniq_name = "_QFsub1En1"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %1 = fir.declare %arg2 {uniq_name = "_QFsub1En2"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %2 = fir.load %0 : !fir.ref<i32>
+  %3 = fir.convert %2 : (i32) -> index
+  %4 = arith.cmpi sgt, %3, %c0 : index
+  %5 = arith.select %4, %3, %c0 : index
+  %6 = fir.load %1 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %5, %9 : (index, index) -> !fir.shape<2>
+  %11 = fir.declare %arg0(%10) {acc.declare = #acc.declare<dataClause =  acc_present>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
+  %12 = fir.embox %11(%10) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
+  %13:3 = fir.box_dims %12, %c0 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+  %14 = arith.subi %13#1, %c1 : index
+  %15 = acc.bounds lowerbound(%c0 : index) upperbound(%14 : index) extent(%13#1 : index) stride(%13#2 : index) startIdx(%c1 : index) {strideInBytes = true}
+  %16 = arith.muli %13#2, %13#1 : index
+  %17:3 = fir.box_dims %12, %c1 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
+  %18 = arith.subi %17#1, %c1 : index
+  %19 = acc.bounds lowerbound(%c0 : index) upperbound(%18 : index) extent(%17#1 : index) stride(%16 : index) startIdx(%c1 : index) {strideInBytes = true}
+  %20 = acc.present varPtr(%11 : !fir.ref<!fir.array<?x?xf32>>) bounds(%15, %19) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
+  %21 = acc.declare_enter dataOperands(%20 : !fir.ref<!fir.array<?x?xf32>>)
+  acc.declare_exit token(%21)
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "a"}
+// CHECK: %[[DECL:.*]] = fir.declare %[[ARG0]](%{{.*}}) {acc.declare = #acc.declare<dataClause =  acc_present>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
+// CHECK: %[[PRES:.*]] = acc.present varPtr(%[[DECL]] : !fir.ref<!fir.array<?x?xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
+// CHECK: %{{.*}} = acc.declare_enter dataOperands(%[[PRES]] : !fir.ref<!fir.array<?x?xf32>>)

From b2031ccea9dba83c2ade67fb3d79b488f03ec0a7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 2 Feb 2024 16:48:19 -0800
Subject: [PATCH 2/2] Use starts_with and update tests

---
 flang/lib/Optimizer/Dialect/FIROps.cpp            | 4 ++--
 flang/test/Fir/OpenACC/propagate-attr-folding.fir | 9 +++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 96a676149203d..777cbf6b4bb26 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -43,8 +43,8 @@ static void propagateAttributes(mlir::Operation *fromOp,
     return;
 
   for (mlir::NamedAttribute attr : fromOp->getAttrs()) {
-    if (attr.getName().str().rfind(
-            mlir::acc::OpenACCDialect::getDialectNamespace(), 0) == 0)
+    if (attr.getName().getValue().starts_with(
+            mlir::acc::OpenACCDialect::getDialectNamespace()))
       toOp->setAttr(attr.getName().str(), attr.getValue());
   }
 }
diff --git a/flang/test/Fir/OpenACC/propagate-attr-folding.fir b/flang/test/Fir/OpenACC/propagate-attr-folding.fir
index f84269f5fef26..99ac7516690e2 100644
--- a/flang/test/Fir/OpenACC/propagate-attr-folding.fir
+++ b/flang/test/Fir/OpenACC/propagate-attr-folding.fir
@@ -17,7 +17,7 @@ func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "a"},
   %8 = arith.cmpi sgt, %7, %c0 : index
   %9 = arith.select %8, %7, %c0 : index
   %10 = fir.shape %5, %9 : (index, index) -> !fir.shape<2>
-  %11 = fir.declare %arg0(%10) {acc.declare = #acc.declare<dataClause =  acc_present>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
+  %11 = fir.declare %arg0(%10) {uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.ref<!fir.array<?x?xf32>>
   %12 = fir.embox %11(%10) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
   %13:3 = fir.box_dims %12, %c0 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
   %14 = arith.subi %13#1, %c1 : index
@@ -26,9 +26,10 @@ func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "a"},
   %17:3 = fir.box_dims %12, %c1 : (!fir.box<!fir.array<?x?xf32>>, index) -> (index, index, index)
   %18 = arith.subi %17#1, %c1 : index
   %19 = acc.bounds lowerbound(%c0 : index) upperbound(%18 : index) extent(%17#1 : index) stride(%16 : index) startIdx(%c1 : index) {strideInBytes = true}
-  %20 = acc.present varPtr(%11 : !fir.ref<!fir.array<?x?xf32>>) bounds(%15, %19) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
-  %21 = acc.declare_enter dataOperands(%20 : !fir.ref<!fir.array<?x?xf32>>)
-  acc.declare_exit token(%21)
+  %20 = fir.box_addr %12 {acc.declare = #acc.declare<dataClause =  acc_present>} : (!fir.box<!fir.array<?x?xf32>>) -> !fir.ref<!fir.array<?x?xf32>>
+  %21 = acc.present varPtr(%20 : !fir.ref<!fir.array<?x?xf32>>) bounds(%15, %19) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
+  %22 = acc.declare_enter dataOperands(%21 : !fir.ref<!fir.array<?x?xf32>>)
+  acc.declare_exit token(%22)
   return
 }
 



More information about the flang-commits mailing list