[Mlir-commits] [mlir] c3c73e5 - [mlir][openacc] Add canonicalization pattern for acc.host_data

Valentin Clement llvmlistbot at llvm.org
Fri May 12 16:58:19 PDT 2023


Author: Valentin Clement
Date: 2023-05-12T16:57:58-07:00
New Revision: c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e

URL: https://github.com/llvm/llvm-project/commit/c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e
DIFF: https://github.com/llvm/llvm-project/commit/c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e.diff

LOG: [mlir][openacc] Add canonicalization pattern for acc.host_data

Add an if-condition removal pattern for acc.host_data in the same way as
acc.enter_data, acc.exit_data and acc.update.

The condition is removed from the op if it is a true constant. If it is
a false constant, the region is inlined before the op and the op is
removed.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D150480

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 5370c8fa8feb6..52bbdbb2aee84 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -784,6 +784,7 @@ def OpenACC_HostDataOp : OpenACC_Op<"host_data", [AttrSizedOperandSegments]> {
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 1f73fc795caa3..4013657a4b284 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -292,6 +292,46 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+///
+/// The region's block is inlined immediately before `op`, with `blockArgs`
+/// substituted for the block's arguments. The terminator's operands become the
+/// replacement values for `op`'s results, and the terminator is then erased.
+/// `region` must contain exactly one block.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  // Capture the yielded values up front; they replace op's results once the
+  // block body has been spliced into the parent.
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  // The terminator was moved together with the rest of the block; it is not
+  // needed in the parent, so erase it.
+  rewriter.eraseOp(terminator);
+}
+
+/// Pattern to remove an operation with a region that has a constant false
+/// `ifCond`, and to remove the condition from the operation if the `ifCond` is
+/// constant true.
+///
+/// - Constant true: the `ifCond` operand is dropped in place.
+/// - Constant false: the region's single block is inlined before the op and
+///   the op is removed. If the region has no block at all, the op is simply
+///   erased, provided its results are unused. Multi-block regions are left
+///   untouched because they cannot be inlined as straight-line code.
+template <typename OpTy>
+struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Early return if there is no condition.
+    Value ifCond = op.getIfCond();
+    if (!ifCond)
+      return failure();
+
+    IntegerAttr constAttr;
+    if (!matchPattern(ifCond, m_Constant(&constAttr)))
+      return failure();
+
+    if (constAttr.getInt()) {
+      // Constant true: the condition is redundant, so drop the operand.
+      rewriter.updateRootInPlace(op,
+                                 [&]() { op.getIfCondMutable().erase(0); });
+      return success();
+    }
+
+    // Constant false: hoist the region body out and remove the op.
+    Region &region = op.getRegion();
+    if (region.empty()) {
+      // An empty region (e.g. `acc.host_data ... if(%false) {}`) has no block
+      // to inline; calling replaceOpWithRegion here would trip its
+      // single-block assertion. Just drop the op if nothing uses its results.
+      if (!op->use_empty())
+        return failure();
+      rewriter.eraseOp(op);
+      return success();
+    }
+    // replaceOpWithRegion can only inline a single block.
+    if (!llvm::hasSingleElement(region))
+      return failure();
+    replaceOpWithRegion(rewriter, op, region);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -386,6 +426,11 @@ LogicalResult acc::HostDataOp::verify() {
   return success();
 }
 
+/// Canonicalizations for `acc.host_data`: fold away a constant `ifCond` via
+/// RemoveConstantIfConditionWithRegion.
+void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  using IfCondPattern = RemoveConstantIfConditionWithRegion<HostDataOp>;
+  results.add<IfCondPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index 91ce9d31a4d0b..9b16db9802754 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -105,3 +105,40 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
 
 // CHECK:  func @testupdateop(%{{.*}}: memref<f32>, [[IFCOND:%.*]]: i1)
 // CHECK:    acc.update if(%{{.*}}) dataOperands(%{{.*}} : memref<f32>)
+
+// -----
+
+// Constant false `if` condition: the host_data region body (two acc.loop ops)
+// is inlined in place and the acc.host_data op itself is removed.
+func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+  %false = arith.constant false
+  acc.host_data dataOperands(%0 : memref<f32>) if(%false) {
+    acc.loop {
+      acc.yield
+    }
+    acc.loop {
+      acc.yield
+    }
+    acc.terminator
+  }
+  return
+}
+
+// Expect no acc.host_data, followed by both hoisted loops.
+// CHECK-LABEL: func.func @testhostdataop
+// CHECK-NOT: acc.host_data
+// CHECK: acc.loop
+// CHECK: acc.yield
+// CHECK: acc.loop
+// CHECK: acc.yield
+
+// -----
+
+// Constant true `if` condition: the acc.host_data op is kept, but its `if`
+// operand is dropped (the printed op has no `if(...)` clause).
+func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+  %true = arith.constant true
+  acc.host_data dataOperands(%0 : memref<f32>) if(%true) {
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @testhostdataop
+// CHECK: acc.host_data dataOperands(%{{.*}} : memref<f32>) {


        


More information about the Mlir-commits mailing list