[flang-commits] [flang] 68bcd64 - [flang][openacc] Add data operands conversion from FIR

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Apr 10 12:18:11 PDT 2023


Author: Valentin Clement
Date: 2023-04-10T12:18:05-07:00
New Revision: 68bcd647c9c006b707bc9a675a874658cd085d13

URL: https://github.com/llvm/llvm-project/commit/68bcd647c9c006b707bc9a675a874658cd085d13
DIFF: https://github.com/llvm/llvm-project/commit/68bcd647c9c006b707bc9a675a874658cd085d13.diff

LOG: [flang][openacc] Add data operands conversion from FIR

This patch revives an old PR attempt [1] to perform the
data operand conversion needed for translation to LLVM IR.

Box/class types are currently not supported, since they will
normally not reach this pass once the changes proposed in RFC [2]
are implemented.

[1] https://github.com/flang-compiler/f18-llvm-project/pull/915
[2] https://discourse.llvm.org/t/rfc-openacc-dialect-data-operation-improvements/69825/2

Depends on D147824

Reviewed By: PeteSteinfeld, razvanlupusoru

Differential Revision: https://reviews.llvm.org/D147825

Added: 
    flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
    flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 8af14f8013abb..0d3bd325f001b 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -43,6 +43,7 @@ namespace fir {
 #define GEN_PASS_DECL_SIMPLIFYREGIONLITE
 #define GEN_PASS_DECL_ALGEBRAICSIMPLIFICATION
 #define GEN_PASS_DECL_POLYMORPHICOPCONVERSION
+#define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
 std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
@@ -70,6 +71,7 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
 std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
 std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
+std::unique_ptr<mlir::Pass> createOpenACCDataOperandConversionPass();
 
 // declarative passes
 #define GEN_PASS_REGISTRATION

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index b8ad0243b6af7..4ac85c4f829d5 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -284,5 +284,14 @@ def PolymorphicOpConversion : Pass<"fir-polymorphic-op", "::mlir::func::FuncOp">
   ];
 }
 
-  
+// Function-level pass that rewrites the FIR-typed data operands of OpenACC
+// operations into LLVM dialect types, as needed for translation to LLVM IR.
+// NOTE(review): the `use-opaque-pointers` option declared below is not passed
+// to the FIR type converter by the pass implementation, so it currently has
+// no effect on the generated pointer types.
+def OpenACCDataOperandConversion : Pass<"fir-openacc-data-operand-conversion", "::mlir::func::FuncOp"> {
+  let summary = "Convert the FIR operands in OpenACC ops to LLVM dialect";
+  let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+           /*default=*/"true", "Generate LLVM IR using opaque pointers "
+           "instead of typed pointers">,
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index ca690c341bb26..6ca406629f1b1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_flang_library(FIRTransforms
   SimplifyIntrinsics.cpp
   AddDebugFoundation.cpp
   PolymorphicOpConversion.cpp
+  OpenACC/OpenACCDataOperandConversion.cpp
 
   DEPENDS
   FIRDialect
@@ -22,6 +23,7 @@ add_flang_library(FIRTransforms
 
   LINK_LIBS
   FIRBuilder
+  FIRCodeGen
   FIRDialect
   FIRDialectSupport
   FIRSupport

diff  --git a/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
new file mode 100644
index 0000000000000..f6d6524da548f
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
@@ -0,0 +1,180 @@
+//===- OpenACCDataOperandConversion.cpp - OpenACC data operand conversion -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_OPENACCDATAOPERANDCONVERSION
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-openacc-conversion"
+#include "../CodeGen/TypeConverter.h"
+
+using namespace fir;
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Conversion pattern that legalizes the data operands of the OpenACC
+/// operation `Op` for translation to LLVM IR.
+///
+/// Every data operand must be a `!fir.ref<T>` whose element type is not a
+/// box/class (`fir::BaseBoxType`). Each one is replaced by an
+/// `unrealized_conversion_cast` to the LLVM pointer type produced by the FIR
+/// type converter. Non-data operands are forwarded unchanged from the adaptor.
+/// The match fails (and the conversion rolls back) for any other operand kind.
+///
+/// Regions attached to the original operation (e.g. the bodies of `acc.data`
+/// and `acc.parallel`) are moved into the replacement. The generic builder
+/// used to create the replacement only creates empty regions, and the old
+/// body would otherwise be erased together with the old operation.
+template <typename Op>
+class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
+  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &builder) const override {
+    Location loc = op.getLoc();
+    fir::LLVMTypeConverter &converter =
+        *static_cast<fir::LLVMTypeConverter *>(this->getTypeConverter());
+
+    unsigned numDataOperands = op.getNumDataOperands();
+
+    // The data operands are assumed to be the trailing operands of the op
+    // (the take_front below relies on it); the leading non-data operands are
+    // kept without modification.
+    auto nonDataOperands = adaptor.getOperands().take_front(
+        adaptor.getOperands().size() - numDataOperands);
+    SmallVector<Value> convertedOperands;
+    convertedOperands.reserve(nonDataOperands.size() + numDataOperands);
+    convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
+
+    // Legalize each data operand: only references to non-box types are
+    // supported, and they are cast to the converted LLVM pointer type.
+    for (unsigned idx = 0; idx < numDataOperands; ++idx) {
+      Value originalDataOperand = op.getDataOperand(idx);
+      auto refTy =
+          originalDataOperand.getType().dyn_cast<fir::ReferenceType>();
+      if (!refTy)
+        return builder.notifyMatchFailure(op, "expecting a reference type");
+      if (refTy.getEleTy().isa<fir::BaseBoxType>())
+        return builder.notifyMatchFailure(op, "BaseBoxType not supported");
+
+      mlir::Type convertedType = converter.convertType(refTy);
+      // Report a match failure instead of asserting when the converter does
+      // not produce an LLVM pointer for the reference.
+      if (!convertedType || !convertedType.isa<mlir::LLVM::LLVMPointerType>())
+        return builder.notifyMatchFailure(
+            op, "expecting the reference to convert to an LLVM pointer");
+
+      mlir::Value castedOperand =
+          builder
+              .create<mlir::UnrealizedConversionCastOp>(loc, convertedType,
+                                                        originalDataOperand)
+              .getResult(0);
+      convertedOperands.push_back(castedOperand);
+    }
+
+    // Build the replacement with the converted operands and the original
+    // attributes, then move every region of the old op into it so that the
+    // bodies of region-carrying ops are preserved.
+    Operation *oldOp = op.getOperation();
+    Operation *newOp =
+        builder
+            .create<Op>(loc, TypeRange(), convertedOperands, oldOp->getAttrs())
+            .getOperation();
+    for (unsigned i = 0, e = oldOp->getNumRegions(); i < e; ++i) {
+      auto &newRegion = newOp->getRegion(i);
+      builder.inlineRegionBefore(oldOp->getRegion(i), newRegion,
+                                 newRegion.end());
+    }
+    builder.eraseOp(oldOp);
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Function pass that legalizes the data operands of OpenACC operations for
+/// LLVM IR translation. The pass options and boilerplate come from the
+/// TableGen-generated base class; the work is done in runOnOperation().
+class OpenACCDataOperandConversion
+    : public fir::impl::OpenACCDataOperandConversionBase<
+          OpenACCDataOperandConversion> {
+public:
+  // Inherit the generated constructors (default and options-based).
+  using Base::Base;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Converts the data operands of acc.data, acc.enter_data, acc.exit_data,
+/// acc.parallel and acc.update in the current function from FIR references
+/// to LLVM pointers.
+///
+/// An OpenACC op is legal once all of its data operand groups have LLVM
+/// pointer type. FIR and LLVM dialect ops, as well as unrealized conversion
+/// casts, are always legal. The pass signals failure if the partial
+/// conversion fails.
+///
+/// TODO(review): the `useOpaquePointers` pass option is not forwarded to the
+/// fir::LLVMTypeConverter built below, so it does not influence the converted
+/// pointer types. Wire it through once the FIR type converter can be
+/// configured with it.
+void OpenACCDataOperandConversion::runOnOperation() {
+  auto op = getOperation();
+  auto *context = op.getContext();
+
+  // Convert to OpenACC operations with LLVM IR dialect operands.
+  fir::LLVMTypeConverter converter(
+      op.getOperation()->getParentOfType<mlir::ModuleOp>(), true);
+  RewritePatternSet patterns(context);
+  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>,
+               LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>,
+               LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>,
+               LegalizeDataOpForLLVMTranslation<acc::ParallelOp>,
+               LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<fir::FIROpsDialect>();
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<UnrealizedConversionCastOp>();
+
+  // True when every value in `operands` already has LLVM pointer type.
+  auto allDataOperandsAreConverted = [](ValueRange operands) {
+    for (Value operand : operands) {
+      if (!operand.getType().isa<LLVM::LLVMPointerType>())
+        return false;
+    }
+    return true;
+  };
+
+  target.addDynamicallyLegalOp<acc::DataOp>(
+      [allDataOperandsAreConverted](acc::DataOp op) {
+        return allDataOperandsAreConverted(op.getCopyOperands()) &&
+               allDataOperandsAreConverted(op.getCopyinOperands()) &&
+               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
+               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
+               allDataOperandsAreConverted(op.getCreateOperands()) &&
+               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
+               allDataOperandsAreConverted(op.getPresentOperands()) &&
+               allDataOperandsAreConverted(op.getDeviceptrOperands()) &&
+               allDataOperandsAreConverted(op.getAttachOperands());
+      });
+
+  target.addDynamicallyLegalOp<acc::EnterDataOp>(
+      [allDataOperandsAreConverted](acc::EnterDataOp op) {
+        return allDataOperandsAreConverted(op.getCopyinOperands()) &&
+               allDataOperandsAreConverted(op.getCreateOperands()) &&
+               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+               allDataOperandsAreConverted(op.getAttachOperands());
+      });
+
+  target.addDynamicallyLegalOp<acc::ExitDataOp>(
+      [allDataOperandsAreConverted](acc::ExitDataOp op) {
+        return allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+               allDataOperandsAreConverted(op.getDeleteOperands()) &&
+               allDataOperandsAreConverted(op.getDetachOperands());
+      });
+
+  target.addDynamicallyLegalOp<acc::ParallelOp>(
+      [allDataOperandsAreConverted](acc::ParallelOp op) {
+        return allDataOperandsAreConverted(op.getReductionOperands()) &&
+               allDataOperandsAreConverted(op.getCopyOperands()) &&
+               allDataOperandsAreConverted(op.getCopyinOperands()) &&
+               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
+               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
+               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
+               allDataOperandsAreConverted(op.getCreateOperands()) &&
+               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
+               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
+               allDataOperandsAreConverted(op.getPresentOperands()) &&
+               allDataOperandsAreConverted(op.getDevicePtrOperands()) &&
+               allDataOperandsAreConverted(op.getAttachOperands()) &&
+               allDataOperandsAreConverted(op.getGangPrivateOperands()) &&
+               allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
+      });
+
+  target.addDynamicallyLegalOp<acc::UpdateOp>(
+      [allDataOperandsAreConverted](acc::UpdateOp op) {
+        return allDataOperandsAreConverted(op.getHostOperands()) &&
+               allDataOperandsAreConverted(op.getDeviceOperands());
+      });
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}

diff  --git a/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
new file mode 100644
index 0000000000000..12c4c7737a1a0
--- /dev/null
+++ b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
@@ -0,0 +1,84 @@
+// RUN: fir-opt -fir-openacc-data-operand-conversion='use-opaque-pointers=1' -split-input-file %s | FileCheck %s
+
+// acc.data: the !fir.ref copy operand is expected to be replaced by an
+// unrealized_conversion_cast to an LLVM pointer.
+// NOTE(review): the RUN line passes use-opaque-pointers=1, yet the expected
+// types in this file are typed pointers (!llvm.ptr<array<10 x f32>>); the
+// option does not appear to take effect. Only the operand list of acc.data is
+// matched below; the contents of its region are not verified.
+func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.data copy(%0 : !fir.ref<!fir.array<10xf32>>) {
+    acc.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub1() attributes {fir.bindc_name = "arr"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.data copy(%[[CAST]] : !llvm.ptr<array<10 x f32>>)
+
+// -----
+
+// acc.enter_data / acc.exit_data: each op is expected to get its own cast of
+// the same !fir.ref address (CAST0 and CAST1 below).
+func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.enter_data copyin(%0 : !fir.ref<!fir.array<10xf32>>)
+  acc.exit_data copyout(%0 : !fir.ref<!fir.array<10xf32>>)
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub_enter_exit() attributes {fir.bindc_name = "a"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.enter_data copyin(%[[CAST0]] : !llvm.ptr<array<10 x f32>>)
+// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.exit_data copyout(%[[CAST1]] : !llvm.ptr<array<10 x f32>>)
+
+// -----
+
+// acc.update: the device operand is expected to be cast to an LLVM pointer.
+func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  acc.update device(%0 : !fir.ref<!fir.array<10xf32>>)
+  return
+}
+
+// CHECK-LABEL: func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.update device(%[[CAST]] : !llvm.ptr<array<10 x f32>>)
+
+// -----
+
+// acc.parallel: the copyin operand is expected to be cast to an LLVM pointer.
+// The FIR operations inside the region are legal for this pass.
+func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
+  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+  %1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
+  acc.parallel copyin(%0: !fir.ref<!fir.array<10xf32>>) {
+    acc.loop {
+      %c1_i32 = arith.constant 1 : i32
+      %2 = fir.convert %c1_i32 : (i32) -> index
+      %c10_i32 = arith.constant 10 : i32
+      %3 = fir.convert %c10_i32 : (i32) -> index
+      %c1 = arith.constant 1 : index
+      %4 = fir.convert %2 : (index) -> i32
+      %5:2 = fir.do_loop %arg0 = %2 to %3 step %c1 iter_args(%arg1 = %4) -> (index, i32) {
+        fir.store %arg1 to %1 : !fir.ref<i32>
+        %6 = fir.load %1 : !fir.ref<i32>
+        %7 = fir.convert %6 : (i32) -> f32
+        %c10_i64 = arith.constant 10 : i64
+        %c1_i64 = arith.constant 1 : i64
+        %8 = arith.subi %c10_i64, %c1_i64 : i64
+        %9 = fir.coordinate_of %0, %8 : (!fir.ref<!fir.array<10xf32>>, i64) -> !fir.ref<f32>
+        fir.store %7 to %9 : !fir.ref<f32>
+        %10 = arith.addi %arg0, %c1 : index
+        %11 = fir.convert %c1 : (index) -> i32
+        %12 = fir.load %1 : !fir.ref<i32>
+        %13 = arith.addi %12, %11 : i32
+        fir.result %10, %13 : index, i32
+      }
+      fir.store %5#1 to %1 : !fir.ref<i32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// NOTE(review): the directives below match only the rewritten operand list and
+// the opening brace of the acc.parallel region; the region body (acc.loop,
+// fir.do_loop, ...) is not verified to survive the conversion.
+// CHECK-LABEL: func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
+// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
+// CHECK: acc.parallel copyin(%[[CAST]]: !llvm.ptr<array<10 x f32>>) {


        


More information about the flang-commits mailing list