[Mlir-commits] [mlir] fb5b590 - [mlir][openacc] Add conversion for if operand to scf.if for standalone data operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 7 09:10:20 PDT 2021


Author: Valentin Clement
Date: 2021-06-07T12:10:03-04:00
New Revision: fb5b590b5e18796bf850170946f15fc10ab9394a

URL: https://github.com/llvm/llvm-project/commit/fb5b590b5e18796bf850170946f15fc10ab9394a
DIFF: https://github.com/llvm/llvm-project/commit/fb5b590b5e18796bf850170946f15fc10ab9394a.diff

LOG: [mlir][openacc] Add conversion for if operand to scf.if for standalone data operation

This patch converts the if condition on standalone data operations such as acc.update,
acc.enter_data and acc.exit_data to an scf.if with the operation in the if region.
It removes the operation when the if condition is constant and false. It removes
the condition if it is constant and true.

Conversion to scf.if is done in order to use the translation to LLVM IR dialect out of the box.
It is not clear whether this is the best approach, or whether this should instead be performed
during the translation from OpenACC to LLVM IR dialect. Any thoughts are welcome.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D103325

Added: 
    mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
    mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
    mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
    mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/PassDetail.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
new file mode 100644
index 000000000000..50d6bd880cde
--- /dev/null
+++ b/mlir/include/mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h
@@ -0,0 +1,28 @@
+//===- ConvertOpenACCToSCF.h - OpenACC conversion pass entrypoint ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+#define MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H
+
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class RewritePatternSet;
+
+/// Collect the patterns that rewrite the `ifCond` operand of standalone
+/// OpenACC data operations into SCF control flow.
+void populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert conditional OpenACC data operations to scf.if.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenACCToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_OPENACCTOSCF_CONVERTOPENACCTOSCF_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 0fd7fdd131b9..d75e617a902c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,6 +23,7 @@
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a659c20e6530..508a0084015f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -255,6 +255,16 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
   let dependentDialects = ["StandardOpsDialect", "vector::VectorDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// OpenACCToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertOpenACCToSCF : Pass<"convert-openacc-to-scf", "ModuleOp"> {
+  let dependentDialects = ["scf::SCFDialect", "acc::OpenACCDialect"];
+  let constructor = "mlir::createConvertOpenACCToSCFPass()";
+  let summary = "Convert the OpenACC ops to OpenACC with SCF dialect";
+}
+
 //===----------------------------------------------------------------------===//
 // OpenACCToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 3b6756798b2c..61011ee0aa4b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -503,13 +503,13 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$asyncOperand,
+  let arguments = (ins Optional<I1>:$ifCond,
+                       Optional<IntOrIndex>:$asyncOperand,
                        Optional<IntOrIndex>:$waitDevnum,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$async,
                        UnitAttr:$wait,
                        Variadic<IntOrIndex>:$deviceTypeOperands,
-                       Optional<I1>:$ifCond,
                        Variadic<AnyType>:$hostOperands,
                        Variadic<AnyType>:$deviceOperands,
                        UnitAttr:$ifPresent);

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e61e80423bcd..b89be2a56941 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(MathToLibm)
 add_subdirectory(OpenACCToLLVM)
+add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(SCFToGPU)

diff  --git a/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
new file mode 100644
index 000000000000..aab569371275
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIROpenACCToSCF
+  OpenACCToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenACCToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIROpenACC
+  MLIRTransforms
+  MLIRSCF
+  MLIRStandard
+  )

diff  --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
new file mode 100644
index 000000000000..cc92982b956a
--- /dev/null
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -0,0 +1,116 @@
+//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pattern to transform the `ifCond` on an operation without region into an
+/// scf.if and move the operation into the `then` region.
+///
+/// Three cases are handled:
+///   - `ifCond` is produced by a constant `true`: the condition operand is
+///     dropped and the operation is kept unconditionally.
+///   - `ifCond` is produced by a constant `false`: the operation is erased.
+///   - Otherwise, an scf.if guarded by `ifCond` is created and a copy of the
+///     operation, without the condition operand, is placed in its `then`
+///     region.
+template <typename OpTy>
+class ExpandIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to rewrite without a condition. Report failure rather than
+    // success so the driver does not treat the untouched op as rewritten.
+    Value ifCond = op.ifCond();
+    if (!ifCond)
+      return failure();
+
+    // Condition known at compile time: fold it instead of emitting scf.if.
+    if (auto constOp = ifCond.getDefiningOp<ConstantOp>()) {
+      if (auto boolAttr = constOp.getValue().dyn_cast<BoolAttr>()) {
+        if (boolAttr.getValue()) {
+          // Always executed: only drop the condition operand.
+          rewriter.updateRootInPlace(
+              op, [&]() { op.ifCondMutable().erase(0); });
+        } else {
+          // Never executed: remove the operation entirely.
+          rewriter.eraseOp(op);
+        }
+        return success();
+      }
+    }
+
+    // Dynamic condition: wrap a condition-free clone of the operation in an
+    // scf.if guarded by the original condition, then erase the original.
+    auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), ifCond,
+                                           /*withElseRegion=*/false);
+    rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+    auto thenBodyBuilder = ifOp.getThenBodyBuilder();
+    // Forward the rewriter's listener so the clone is tracked by the driver.
+    thenBodyBuilder.setListener(rewriter.getListener());
+    thenBodyBuilder.clone(*op.getOperation());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
+  patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
+}
+
+namespace {
+struct ConvertOpenACCToSCFPass
+    : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertOpenACCToSCFPass::runOnOperation() {
+  auto op = getOperation();
+  auto *context = op.getContext();
+
+  RewritePatternSet patterns(context);
+  ConversionTarget target(*context);
+  populateOpenACCToSCFConversionPatterns(patterns);
+
+  target.addLegalDialect<scf::SCFDialect>();
+  target.addLegalDialect<acc::OpenACCDialect>();
+
+  // Standalone data operations become legal once they no longer carry an
+  // `ifCond` operand (folded away or moved into an scf.if).
+  target.addDynamicallyLegalOp<acc::EnterDataOp>(
+      [](acc::EnterDataOp op) { return !op.ifCond(); });
+
+  target.addDynamicallyLegalOp<acc::ExitDataOp>(
+      [](acc::ExitDataOp op) { return !op.ifCond(); });
+
+  target.addDynamicallyLegalOp<acc::UpdateOp>(
+      [](acc::UpdateOp op) { return !op.ifCond(); });
+
+  if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
+  return std::make_unique<ConvertOpenACCToSCFPass>();
+}

diff  --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 7ff8c903348e..5287c6c1490b 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -19,6 +19,10 @@ class StandardOpsDialect;
 template <typename ConcreteDialect>
 void registerDialect(DialectRegistry &registry);
 
+namespace acc {
+class OpenACCDialect;
+} // end namespace acc
+
 namespace complex {
 class ComplexDialect;
 } // end namespace complex

diff  --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
new file mode 100644
index 000000000000..29d2bb66e728
--- /dev/null
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK:      func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:        scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.enter_data create(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK:      func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:        scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.exit_data delete(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK:      func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:        scf.if [[IFCOND]] {
+// CHECK-NEXT:     acc.update host(%{{.*}} : memref<10xf32>)
+// CHECK-NEXT:   }
+
+// -----
+
+// An operation without `if` operand is already legal and must be left as is.
+func @testenterdataopnoifcond(%a: memref<10xf32>) -> () {
+  acc.enter_data create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK:      func @testenterdataopnoifcond(%{{.*}}: memref<10xf32>)
+// CHECK-NOT:    scf.if
+// CHECK:        acc.enter_data create(%{{.*}} : memref<10xf32>)


        


More information about the Mlir-commits mailing list