[Mlir-commits] [mlir] aa4e6a6 - [mlir][openacc] Add canonicalization for standalone data operations for if condition

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 7 08:41:11 PDT 2021


Author: Valentin Clement
Date: 2021-06-07T11:40:59-04:00
New Revision: aa4e6a609acdd00e06b54f525054bd5cf3624f0f

URL: https://github.com/llvm/llvm-project/commit/aa4e6a609acdd00e06b54f525054bd5cf3624f0f
DIFF: https://github.com/llvm/llvm-project/commit/aa4e6a609acdd00e06b54f525054bd5cf3624f0f.diff

LOG: [mlir][openacc] Add canonicalization for standalone data operations for if condition

This patch adds canonicalization for the standalone data operations with a constant if condition.
It was extracted from patch D103325.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D103712

Added: 
    mlir/test/Dialect/OpenACC/canonicalize.mlir

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 37867e6e1998e..3b6756798b2c7 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -283,6 +283,8 @@ def OpenACC_EnterDataOp : OpenACC_Op<"enter_data", [AttrSizedOperandSegments]> {
     ( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -331,6 +333,8 @@ def OpenACC_ExitDataOp : OpenACC_Op<"exit_data", [AttrSizedOperandSegments]> {
     ( `detach` `(` $detachOperands^ `:` type($detachOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -529,6 +533,8 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
     ( `device` `(` $deviceOperands^ `:` type($deviceOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index b1cc4e796120a..f823041f29221 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -8,9 +8,11 @@
 
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace acc;
@@ -153,6 +155,31 @@ static bool isComputeOperation(Operation *op) {
   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
 }
 
+namespace {
+/// Pattern for region-less operations with a constant `ifCond` operand: the
+/// operation is erased when the condition is a constant false, and the
+/// condition operand is dropped when it is a constant true.
+template <typename OpTy>
+struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Only report success() when the IR actually changed.
+    if (!op.ifCond())
+      return failure();
+    auto constOp = op.ifCond().template getDefiningOp<ConstantOp>();
+    if (!constOp)
+      return failure();
+    if (constOp.getValue().template cast<IntegerAttr>().getInt())
+      rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -728,6 +755,11 @@ Value ExitDataOp::getDataOperand(unsigned i) {
   return getOperand(waitOperands().size() + numOptional + i);
 }
 
+/// Registers the pattern that simplifies an op with a constant `ifCond`.
+void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                             MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
+}
 //===----------------------------------------------------------------------===//
 // EnterDataOp
 //===----------------------------------------------------------------------===//
@@ -770,6 +802,11 @@ Value EnterDataOp::getDataOperand(unsigned i) {
   return getOperand(waitOperands().size() + numOptional + i);
 }
 
+/// Registers the pattern that simplifies an op with a constant `ifCond`.
+void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
+}
 //===----------------------------------------------------------------------===//
 // InitOp
 //===----------------------------------------------------------------------===//
@@ -836,6 +873,11 @@ Value UpdateOp::getDataOperand(unsigned i) {
                     numOptional + i);
 }
 
+/// Registers the pattern that simplifies an op with a constant `ifCond`.
+void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
+}
 //===----------------------------------------------------------------------===//
 // WaitOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
new file mode 100644
index 0000000000000..a74056630e137
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.enter_data create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop
+// CHECK-NOT: acc.enter_data
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop
+// CHECK-NOT: acc.exit_data
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.update host(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop
+// CHECK-NOT: acc.update
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   acc.enter_data if([[IFCOND]]) create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   acc.exit_data if([[IFCOND]]) delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:   acc.update if([[IFCOND]]) host(%{{.*}} : memref<10xf32>)


        


More information about the Mlir-commits mailing list