[Mlir-commits] [mlir] c3c73e5 - [mlir][openacc] Add canonicalization pattern for acc.host_data
Valentin Clement
llvmlistbot at llvm.org
Fri May 12 16:58:19 PDT 2023
Author: Valentin Clement
Date: 2023-05-12T16:57:58-07:00
New Revision: c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e
URL: https://github.com/llvm/llvm-project/commit/c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e
DIFF: https://github.com/llvm/llvm-project/commit/c3c73e5d40dc4e7815bd9465723f0c8e3bdb3d5e.diff
LOG: [mlir][openacc] Add canonicalization pattern for acc.host_data
Add an if-condition removal pattern for acc.host_data, in the same way as
for acc.enter_data, acc.exit_data and acc.update.
The condition is removed from the op if it is a true constant. If
it is a false constant then the region is inlined before the op
and the op is removed.
Reviewed By: vzakhari
Differential Revision: https://reviews.llvm.org/D150480
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 5370c8fa8feb6..52bbdbb2aee84 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -784,6 +784,7 @@ def OpenACC_HostDataOp : OpenACC_Op<"host_data", [AttrSizedOperandSegments]> {
}];
let hasVerifier = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 1f73fc795caa3..4013657a4b284 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -292,6 +292,46 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Replaces `op` with the contents of the single-block `region`: the block is
+/// inlined immediately before `op`, and the operands of the block terminator
+/// are used to replace the results of `op`. `blockArgs` (empty by default) is
+/// forwarded to `inlineBlockBefore` as the replacement values for the block's
+/// arguments. The terminator itself is erased.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  // Non-owning view into the terminator's operand list.
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, blockArgs);
+  // `results` still refers to the terminator's operands, so `op` must be
+  // replaced before the terminator is erased.
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
+/// Canonicalization pattern for region-carrying ops with an optional `ifCond`
+/// operand:
+///   * constant true  -> the `ifCond` operand is removed from the op;
+///   * constant false -> the op's region is inlined before the op and the op
+///                       is removed (via `replaceOpWithRegion`).
+/// The pattern does not match when there is no condition or when the
+/// condition is not a constant integer.
+template <typename OpTy>
+struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    Value cond = op.getIfCond();
+    IntegerAttr condValue;
+    // Bail out when the condition is absent or not a known constant.
+    if (!cond || !matchPattern(cond, m_Constant(&condValue)))
+      return failure();
+
+    if (condValue.getInt() == 0) {
+      // Constant false: splice the region body in front of the op and drop
+      // the op itself.
+      replaceOpWithRegion(rewriter, op, op.getRegion());
+      return success();
+    }
+
+    // Constant true: the condition is redundant, so strip it from the op.
+    rewriter.updateRootInPlace(op,
+                               [&]() { op.getIfCondMutable().erase(0); });
+    return success();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -386,6 +426,11 @@ LogicalResult acc::HostDataOp::verify() {
return success();
}
+/// Registers the canonicalization patterns for `acc.host_data`, i.e. the
+/// removal of a constant `ifCond` operand.
+void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  using ConstantIfPattern = RemoveConstantIfConditionWithRegion<HostDataOp>;
+  results.add<ConstantIfPattern>(context);
+}
+
//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index 91ce9d31a4d0b..9b16db9802754 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -105,3 +105,40 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
// CHECK: func @testupdateop(%{{.*}}: memref<f32>, [[IFCOND:%.*]]: i1)
// CHECK: acc.update if(%{{.*}}) dataOperands(%{{.*}} : memref<f32>)
+
+// -----
+
+// acc.host_data with a constant-false `if` operand: canonicalization is
+// expected to inline the region's acc.loop ops in place of the op and remove
+// acc.host_data entirely. Note: `%ifCond` is not used by this test.
+func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
+ %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+ %false = arith.constant false
+ acc.host_data dataOperands(%0 : memref<f32>) if(%false) {
+ acc.loop {
+ acc.yield
+ }
+ acc.loop {
+ acc.yield
+ }
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @testhostdataop
+// CHECK-NOT: acc.host_data
+// CHECK: acc.loop
+// CHECK: acc.yield
+// CHECK: acc.loop
+// CHECK: acc.yield
+
+// -----
+
+// acc.host_data with a constant-true `if` operand: canonicalization is
+// expected to keep the op but drop its `if` clause. Note: `%ifCond` is not
+// used by this test.
+func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
+ %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+ %true = arith.constant true
+ acc.host_data dataOperands(%0 : memref<f32>) if(%true) {
+ }
+ return
+}
+
+// The region's opening brace directly after the operand list shows that no
+// `if` clause remains on the op.
+// CHECK-LABEL: func.func @testhostdataop
+// CHECK: acc.host_data dataOperands(%{{.*}} : memref<f32>) {
More information about the Mlir-commits
mailing list