[Mlir-commits] [mlir] aa4e6a6 - [mlir][openacc] Add canonicalization for standalone data operations for if condition
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 7 08:41:11 PDT 2021
Author: Valentin Clement
Date: 2021-06-07T11:40:59-04:00
New Revision: aa4e6a609acdd00e06b54f525054bd5cf3624f0f
URL: https://github.com/llvm/llvm-project/commit/aa4e6a609acdd00e06b54f525054bd5cf3624f0f
DIFF: https://github.com/llvm/llvm-project/commit/aa4e6a609acdd00e06b54f525054bd5cf3624f0f.diff
LOG: [mlir][openacc] Add canonicalization for standalone data operations for if condition
This patch adds canonicalization for the standalone data operations with a constant if condition.
It is extracted from this patch D103325.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D103712
Added:
mlir/test/Dialect/OpenACC/canonicalize.mlir
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 37867e6e1998e..3b6756798b2c7 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -283,6 +283,8 @@ def OpenACC_EnterDataOp : OpenACC_Op<"enter_data", [AttrSizedOperandSegments]> {
( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
attr-dict-with-keyword
}];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -331,6 +333,8 @@ def OpenACC_ExitDataOp : OpenACC_Op<"exit_data", [AttrSizedOperandSegments]> {
( `detach` `(` $detachOperands^ `:` type($detachOperands) `)` )?
attr-dict-with-keyword
}];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -529,6 +533,8 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
( `device` `(` $deviceOperands^ `:` type($deviceOperands) `)` )?
attr-dict-with-keyword
}];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index b1cc4e796120a..f823041f29221 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -8,9 +8,11 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace acc;
@@ -153,6 +155,31 @@ static bool isComputeOperation(Operation *op) {
return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
}
+namespace {
+/// Canonicalization pattern for operations without region that carry an
+/// optional `ifCond` operand defined by a constant:
+///  * constant true: the `ifCond` operand is removed, the operation is kept;
+///  * constant false: the operation is erased.
+/// The pattern reports failure, leaving the IR untouched, when the operation
+/// has no `ifCond` or when `ifCond` is not defined by a `ConstantOp`.
+template <typename OpTy>
+struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to fold without a condition. This must be a failure: returning
+    // success without modifying the IR would tell the rewrite driver that the
+    // pattern applied.
+    Value ifCond = op.ifCond();
+    if (!ifCond)
+      return failure();
+
+    // Only a condition produced by a constant can be folded.
+    auto constOp = ifCond.getDefiningOp<ConstantOp>();
+    if (!constOp)
+      return failure();
+
+    if (constOp.getValue().cast<IntegerAttr>().getInt())
+      rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//
@@ -728,6 +755,11 @@ Value ExitDataOp::getDataOperand(unsigned i) {
return getOperand(waitOperands().size() + numOptional + i);
}
+/// Registers the canonicalization patterns of `acc.exit_data`; currently only
+/// the constant `ifCond` folding implemented by RemoveConstantIfCondition.
+void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                             MLIRContext *context) {
+  using FoldConstantIfCond = RemoveConstantIfCondition<ExitDataOp>;
+  results.add<FoldConstantIfCond>(context);
+}
+
//===----------------------------------------------------------------------===//
// EnterDataOp
//===----------------------------------------------------------------------===//
@@ -770,6 +802,11 @@ Value EnterDataOp::getDataOperand(unsigned i) {
return getOperand(waitOperands().size() + numOptional + i);
}
+/// Registers the canonicalization patterns of `acc.enter_data`; currently only
+/// the constant `ifCond` folding implemented by RemoveConstantIfCondition.
+void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  using FoldConstantIfCond = RemoveConstantIfCondition<EnterDataOp>;
+  results.add<FoldConstantIfCond>(context);
+}
+
//===----------------------------------------------------------------------===//
// InitOp
//===----------------------------------------------------------------------===//
@@ -836,6 +873,11 @@ Value UpdateOp::getDataOperand(unsigned i) {
numOptional + i);
}
+/// Registers the canonicalization patterns of `acc.update`; currently only
+/// the constant `ifCond` folding implemented by RemoveConstantIfCondition.
+void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  using FoldConstantIfCond = RemoveConstantIfCondition<UpdateOp>;
+  results.add<FoldConstantIfCond>(context);
+}
+
//===----------------------------------------------------------------------===//
// WaitOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
new file mode 100644
index 0000000000000..a74056630e137
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+// Constant true condition: the expectation below is an acc.enter_data printed
+// without its if clause.
+func @testenterdataop(%a: memref<10xf32>) -> () {
+ %ifCond = constant true
+ acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: acc.enter_data create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+// Constant false condition: the expectation below is that no acc.enter_data
+// remains in the function.
+func @testenterdataop(%a: memref<10xf32>) -> () {
+ %ifCond = constant false
+ acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: func @testenterdataop
+// CHECK-NOT: acc.enter_data
+
+// -----
+
+// Constant true condition: the expectation below is an acc.exit_data printed
+// without its if clause.
+func @testexitdataop(%a: memref<10xf32>) -> () {
+ %ifCond = constant true
+ acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+// Constant false condition: the expectation below is that no acc.exit_data
+// remains in the function.
+func @testexitdataop(%a: memref<10xf32>) -> () {
+ %ifCond = constant false
+ acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: func @testexitdataop
+// CHECK-NOT: acc.exit_data
+
+// -----
+
+// Constant true condition: the expectation below is an acc.update printed
+// without its if clause.
+func @testupdateop(%a: memref<10xf32>) -> () {
+ %ifCond = constant true
+ acc.update if(%ifCond) host(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: acc.update host(%{{.*}} : memref<10xf32>)
+
+// -----
+
+// Constant false condition: the expectation below is that no acc.update
+// remains in the function.
+func @testupdateop(%a: memref<10xf32>) -> () {
+ %ifCond = constant false
+ acc.update if(%ifCond) host(%a: memref<10xf32>)
+ return
+}
+
+// CHECK: func @testupdateop
+// CHECK-NOT: acc.update
+
+// -----
+
+// Non-constant condition (a function argument): the operation must keep its
+// if clause, and that clause must still use the function argument.
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.enter_data if([[IFCOND]]) create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+// Non-constant condition (a function argument): the operation must keep its
+// if clause, and that clause must still use the function argument.
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.exit_data if([[IFCOND]]) delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+// Non-constant condition (a function argument): the operation must keep its
+// if clause, and that clause must still use the function argument.
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.update if([[IFCOND]]) host(%{{.*}} : memref<10xf32>)
More information about the Mlir-commits
mailing list