[flang-commits] [mlir] [flang] [mlir][openacc] Add legalize data pass for compute operation (PR #80351)

via flang-commits flang-commits at lists.llvm.org
Thu Feb 1 13:59:47 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

This patch adds a simple pass to replace the uses inside compute operations. It replaces the `varPtr` values with their corresponding `accPtr` values gathered through the dataClauseOperands.

Private and reduction variables are not included in this pass since they will normally be replaced when they are materialized.

---
Full diff: https://github.com/llvm/llvm-project/pull/80351.diff


12 Files Affected:

- (modified) flang/include/flang/Optimizer/Support/InitFIR.h (+2) 
- (added) flang/test/Fir/OpenACC/legalize-data.fir (+24) 
- (modified) mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt (+2) 
- (added) mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt (+5) 
- (added) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h (+40) 
- (added) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+28) 
- (modified) mlir/include/mlir/InitAllPasses.h (+2) 
- (modified) mlir/lib/Dialect/OpenACC/CMakeLists.txt (+2-20) 
- (added) mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt (+20) 
- (added) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+17) 
- (added) mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp (+72) 
- (added) mlir/test/Dialect/OpenACC/legalize-data.mlir (+88) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 8c47ad3d9f445..b5c41699205f4 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
 /// Register the standard passes we use. This comes from registerAllPasses(),
 /// but is a smaller set since we aren't using many of the passes found there.
 inline void registerMLIRPassesForFortranTools() {
+  mlir::acc::registerOpenACCPasses();
   mlir::registerCanonicalizerPass();
   mlir::registerCSEPass();
   mlir::affine::registerAffineLoopFusionPass();
diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir
new file mode 100644
index 0000000000000..3b8695434e6e4
--- /dev/null
+++ b/flang/test/Fir/OpenACC/legalize-data.fir
@@ -0,0 +1,24 @@
+// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
+
+func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+  acc.parallel dataOperands(%1 : !fir.ref<i32>) {
+    %c0_i32 = arith.constant 0 : i32
+    hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+    acc.yield
+  }
+  acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1
+// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
+// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
+// CHECK:   %c0_i32 = arith.constant 0 : i32
+// CHECK:   hlfir.assign %c0{{.*}} to %[[COPYIN]] : i32, !fir.ref<i32>
+// CHECK:   acc.yield
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
diff --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
index 56ba2976ee5d4..8a4b1c7b196ea 100644
--- a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
 set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
 mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
 add_public_tablegen_target(acc_common_td)
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..ddbd5839576fc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
+add_public_tablegen_target(MLIROpenACCPassIncGen)
+
+add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
new file mode 100644
index 0000000000000..5a11056cda609
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -0,0 +1,40 @@
+//===- Passes.h - OpenACC Passes Construction and Registration ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
+#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
+#include "mlir/Pass/Pass.h"
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace acc {
+
+/// Create a pass to replace ssa values in region with device/host values.
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+
+/// Generate the code for registering conversion passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
new file mode 100644
index 0000000000000..abbc27765e342
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -0,0 +1,28 @@
+//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
+  let summary = "Legalize the data in the compute region";
+  let description = [{
+    This pass replace uses of varPtr in the compute region with their accPtr
+    gathered from the data clause operands.
+  }];
+  let options = [
+    Option<"hostToDevice", "host-to-device", "bool", "true",
+           "Replace varPtr uses with accPtr if true. Replace accPtr uses with "
+           "varPtr if false">
+  ];
+  let constructor = "::mlir::acc::createLegalizeDataInRegion()";
+}
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 28dc3cc23daf2..e28921619fe58 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -64,6 +65,7 @@ inline void registerAllPasses() {
   registerConversionPasses();
 
   // Dialect passes
+  acc::registerOpenACCPasses();
   affine::registerAffinePasses();
   amdgpu::registerAMDGPUPasses();
   registerAsyncPasses();
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 27285246ef997..9f57627c321fb 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,20 +1,2 @@
-add_mlir_dialect_library(MLIROpenACCDialect
-  IR/OpenACC.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
-
-  DEPENDS
-  MLIROpenACCOpsIncGen
-  MLIROpenACCEnumsIncGen
-  MLIROpenACCAttributesIncGen
-  MLIROpenACCOpsInterfacesIncGen
-  MLIROpenACCTypeInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMDialect
-  MLIRMemRefDialect
-  MLIROpenACCMPCommon
-  )
-
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..b802de165b8f3
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIROpenACCDialect
+  OpenACC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+  DEPENDS
+  MLIROpenACCOpsIncGen
+  MLIROpenACCEnumsIncGen
+  MLIROpenACCAttributesIncGen
+  MLIROpenACCOpsInterfacesIncGen
+  MLIROpenACCTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRMemRefDialect
+  MLIROpenACCMPCommon
+  )
+
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..78f676362ecfd
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIROpenACCTransforms
+  LegalizeData.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+  DEPENDS
+  MLIROpenACCPassIncGen
+  MLIROpenACCOpsIncGen
+  MLIROpenACCEnumsIncGen
+  MLIROpenACCAttributesIncGen
+  MLIROpenACCOpsInterfacesIncGen
+  MLIROpenACCTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIROpenACCDialect
+)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
new file mode 100644
index 0000000000000..ef44a0ec68d9c
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -0,0 +1,72 @@
+//===- LegalizeData.cpp - -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> values;
+  for (auto operand : op.getDataClauseOperands()) {
+    Value varPtr = acc::getVarPtr(operand.getDefiningOp());
+    Value accPtr = acc::getAccPtr(operand.getDefiningOp());
+    if (varPtr && accPtr) {
+      if (hostToDevice)
+        values.push_back({varPtr, accPtr});
+      else
+        values.push_back({accPtr, varPtr});
+    }
+  }
+
+  for (auto p : values)
+    replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+}
+
+struct LegalizeDataInRegion
+    : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    bool replaceHostVsDevice = this->hostToDevice.getValue();
+
+    funcOp.walk([&](Operation *op) {
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+        return;
+
+      if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
+        collectAndReplaceInRegion(parallelOp, replaceHostVsDevice);
+      } else if (auto serialOp = dyn_cast<acc::SerialOp>(*op)) {
+        collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
+      } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
+        collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+      }
+    });
+  }
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::acc::createLegalizeDataInRegion() {
+  return std::make_unique<LegalizeDataInRegion>();
+}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
new file mode 100644
index 0000000000000..f9857411306c0
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
+// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.kernels dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.terminator
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel dataOperands(%create : memref<10xf32>) {
+    acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// CHECK:   acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[CREATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }

``````````

</details>


https://github.com/llvm/llvm-project/pull/80351


More information about the flang-commits mailing list