[flang-commits] [flang] 0d09120 - [mlir][openacc] Add legalize data pass for compute operation (#80351)

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Feb 5 13:41:49 PST 2024


Author: Valentin Clement
Date: 2024-02-05T13:40:41-08:00
New Revision: 0d091206dd656c2a9d31d6088a4aa6f9c2cc7156

URL: https://github.com/llvm/llvm-project/commit/0d091206dd656c2a9d31d6088a4aa6f9c2cc7156
DIFF: https://github.com/llvm/llvm-project/commit/0d091206dd656c2a9d31d6088a4aa6f9c2cc7156.diff

LOG: [mlir][openacc] Add legalize data pass for compute operation (#80351)

This patch adds a simple pass that replaces uses inside the region of a compute
operation. It replaces the `varPtr` values with their corresponding `accPtr`
values gathered through the dataClauseOperands.

Private and reduction variables are not included in this pass since they will
normally be replaced when they are materialized.

Reland with fix for dependencies

Added: 
    flang/test/Fir/OpenACC/legalize-data.fir
    mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
    mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
    mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
    mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
    mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
    mlir/test/Dialect/OpenACC/legalize-data.mlir

Modified: 
    flang/include/flang/Optimizer/Support/InitFIR.h
    mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Dialect/OpenACC/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 8c47ad3d9f445..b5c41699205f4 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -74,6 +75,7 @@ inline void loadDialects(mlir::MLIRContext &context) {
 /// Register the standard passes we use. This comes from registerAllPasses(),
 /// but is a smaller set since we aren't using many of the passes found there.
 inline void registerMLIRPassesForFortranTools() {
+  mlir::acc::registerOpenACCPasses();
   mlir::registerCanonicalizerPass();
   mlir::registerCSEPass();
   mlir::affine::registerAffineLoopFusionPass();

diff  --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir
new file mode 100644
index 0000000000000..3b8695434e6e4
--- /dev/null
+++ b/flang/test/Fir/OpenACC/legalize-data.fir
@@ -0,0 +1,24 @@
+// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
+
+func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
+  %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+  acc.parallel dataOperands(%1 : !fir.ref<i32>) {
+    %c0_i32 = arith.constant 0 : i32
+    hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+    acc.yield
+  }
+  acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+  return
+}
+
+// CHECK-LABEL: func.func @_QPsub1
+// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
+// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+// CHECK: acc.parallel dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
+// CHECK:   %c0_i32 = arith.constant 0 : i32
+// CHECK:   hlfir.assign %c0{{.*}} to %[[COPYIN]] : i32, !fir.ref<i32>
+// CHECK:   acc.yield
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}

diff  --git a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
index 56ba2976ee5d4..8a4b1c7b196ea 100644
--- a/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
 set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenACC/ACC.td)
 mlir_tablegen(AccCommon.td --gen-directive-decl --directives-dialect=OpenACC)
 add_public_tablegen_target(acc_common_td)

diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..ddbd5839576fc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenACC)
+add_public_tablegen_target(MLIROpenACCPassIncGen)
+
+add_mlir_doc(Passes OpenACCPasses ./ -gen-pass-doc)

diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
new file mode 100644
index 0000000000000..5a11056cda609
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -0,0 +1,40 @@
+//===- Passes.h - OpenACC Passes Construction and Registration ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
+#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
+#include "mlir/Pass/Pass.h"
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace acc {
+
+/// Create a pass to replace ssa values in region with device/host values.
+std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+
+/// Generate the code for registering conversion passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES_H

diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
new file mode 100644
index 0000000000000..abbc27765e342
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -0,0 +1,28 @@
+//===-- Passes.td - OpenACC pass definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
+  let summary = "Legalize the data in the compute region";
+  let description = [{
+    This pass replace uses of varPtr in the compute region with their accPtr
+    gathered from the data clause operands.
+  }];
+  let options = [
+    Option<"hostToDevice", "host-to-device", "bool", "true",
+           "Replace varPtr uses with accPtr if true. Replace accPtr uses with "
+           "varPtr if false">
+  ];
+  let constructor = "::mlir::acc::createLegalizeDataInRegion()";
+}
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 28dc3cc23daf2..e28921619fe58 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
@@ -64,6 +65,7 @@ inline void registerAllPasses() {
   registerConversionPasses();
 
   // Dialect passes
+  acc::registerOpenACCPasses();
   affine::registerAffinePasses();
   amdgpu::registerAMDGPUPasses();
   registerAsyncPasses();

diff  --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 27285246ef997..9f57627c321fb 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,20 +1,2 @@
-add_mlir_dialect_library(MLIROpenACCDialect
-  IR/OpenACC.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
-
-  DEPENDS
-  MLIROpenACCOpsIncGen
-  MLIROpenACCEnumsIncGen
-  MLIROpenACCAttributesIncGen
-  MLIROpenACCOpsInterfacesIncGen
-  MLIROpenACCTypeInterfacesIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRLLVMDialect
-  MLIRMemRefDialect
-  MLIROpenACCMPCommon
-  )
-
+add_subdirectory(IR)
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..b802de165b8f3
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/IR/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_dialect_library(MLIROpenACCDialect
+  OpenACC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+  DEPENDS
+  MLIROpenACCOpsIncGen
+  MLIROpenACCEnumsIncGen
+  MLIROpenACCAttributesIncGen
+  MLIROpenACCOpsInterfacesIncGen
+  MLIROpenACCTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRMemRefDialect
+  MLIROpenACCMPCommon
+  )
+

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..db3dccd9751cf
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_dialect_library(MLIROpenACCTransforms
+  LegalizeData.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+  DEPENDS
+  MLIROpenACCPassIncGen
+  MLIROpenACCOpsIncGen
+  MLIROpenACCEnumsIncGen
+  MLIROpenACCAttributesIncGen
+  MLIROpenACCOpsInterfacesIncGen
+  MLIROpenACCTypeInterfacesIncGen
+
+  LINK_LIBS PUBLIC
+  MLIROpenACCDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+  MLIRSupport
+  MLIRTransforms
+)

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
new file mode 100644
index 0000000000000..ef44a0ec68d9c
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -0,0 +1,72 @@
+//===- LegalizeData.cpp - -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Replaces, inside the region of `op`, uses of varPtr by accPtr (or the
+/// reverse when `hostToDevice` is false) for each data clause operand.
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> replacements;
+  for (Value dataOperand : op.getDataClauseOperands()) {
+    Operation *clauseOp = dataOperand.getDefiningOp();
+    Value hostValue = acc::getVarPtr(clauseOp);
+    Value deviceValue = acc::getAccPtr(clauseOp);
+    if (!hostValue || !deviceValue)
+      continue; // Only operands exposing both pointers are rewritten.
+    replacements.emplace_back(hostToDevice ? hostValue : deviceValue,
+                              hostToDevice ? deviceValue : hostValue);
+  }
+  for (const auto &[from, to] : replacements)
+    replaceAllUsesInRegionWith(from, to, op.getRegion());
+}
+
+/// Function pass that walks the function and, for each OpenACC compute
+/// construct it finds, rewrites data uses in that construct's region in the
+/// direction selected by the `hostToDevice` option.
+struct LegalizeDataInRegion
+    : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+  void runOnOperation() override {
+    const bool toDevice = this->hostToDevice.getValue();
+    getOperation().walk([toDevice](Operation *nested) {
+      // Operations that are not compute constructs are left untouched.
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*nested))
+        return;
+
+      if (auto parallel = dyn_cast<acc::ParallelOp>(*nested))
+        return collectAndReplaceInRegion(parallel, toDevice);
+      if (auto serial = dyn_cast<acc::SerialOp>(*nested))
+        return collectAndReplaceInRegion(serial, toDevice);
+      if (auto kernels = dyn_cast<acc::KernelsOp>(*nested))
+        collectAndReplaceInRegion(kernels, toDevice);
+    });
+  }
+};
+
+} // end anonymous namespace
+
+auto mlir::acc::createLegalizeDataInRegion()
+    -> std::unique_ptr<OperationPass<func::FuncOp>> {
+  return std::make_unique<LegalizeDataInRegion>();
+}

diff  --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
new file mode 100644
index 0000000000000..f9857411306c0
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
+// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.yield
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>, %i : index) {
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.kernels dataOperands(%create : memref<10xf32>) {
+    %ci = memref.load %a[%i] : memref<10xf32>
+    acc.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.kernels dataOperands(%[[CREATE]] : memref<10xf32>) {
+// DEVICE:   %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST:    %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK:   acc.terminator
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel dataOperands(%create : memref<10xf32>) {
+    acc.loop (%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
+// CHECK:   acc.loop (%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }


        


More information about the flang-commits mailing list