[Mlir-commits] [mlir] 392d55c - [MLIR][GPU] Add canonicalization patterns for folding simple gpu.wait ops.

Uday Bondhugula llvmlistbot at llvm.org
Thu Apr 14 00:02:22 PDT 2022


Author: Arnab Dutta
Date: 2022-04-14T12:30:55+05:30
New Revision: 392d55c1e2d777e9843864e1022d476599334dc9

URL: https://github.com/llvm/llvm-project/commit/392d55c1e2d777e9843864e1022d476599334dc9
DIFF: https://github.com/llvm/llvm-project/commit/392d55c1e2d777e9843864e1022d476599334dc9.diff

LOG: [MLIR][GPU] Add canonicalization patterns for folding simple gpu.wait ops.

* Fold away redundant %t = gpu.wait async + gpu.wait [%t] pairs.

* Fold away %t = gpu.wait async ... ops when %t has no uses.

* Fold away gpu.wait [] ops.

* In case of %t1 = gpu.wait async [%t0], replace all uses of %t1
  with %t0.

Differential Revision: https://reviews.llvm.org/D121878

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Dialect/GPU/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index eaee42bb8e319..078a66dc821b2 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -888,6 +888,8 @@ def GPU_WaitOp : GPU_Op<"wait", [GPU_AsyncOpInterface]> {
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def GPU_AllocOp : GPU_Op<"alloc", [

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d351abc883d97..45da4b2af2262 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1184,6 +1184,78 @@ LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_WaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Drop redundant async dependencies from a gpu.wait op. An operand is
+/// treated as redundant when it is the token of a gpu.wait op that itself has
+/// no async dependencies:
+///   %t = gpu.wait async []       // No async dependencies.
+///   ...  gpu.wait ... [%t, ...]  // %t can be removed.
+/// Only the consuming op is rewritten. The producer of %t is left alone; once
+/// its token has no remaining uses it is erased by SimplifyGpuWaitOp.
+struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WaitOp op,
+                                PatternRewriter &rewriter) const final {
+    // Returns true if `value` is the token of a dependency-free gpu.wait op.
+    auto isRedundantDependency = [](Value value) {
+      auto waitOp = value.getDefiningOp<WaitOp>();
+      return waitOp && waitOp.asyncDependencies().empty();
+    };
+    if (llvm::none_of(op.asyncDependencies(), isRedundantDependency))
+      return failure();
+
+    SmallVector<Value> remainingOperands;
+    for (Value operand : op->getOperands())
+      if (!isRedundantDependency(operand))
+        remainingOperands.push_back(operand);
+
+    // Every IR mutation performed by a pattern must go through the rewriter.
+    // Calling setOperands directly would modify `op` behind the rewriter's
+    // back, so the in-place update is wrapped in updateRootInPlace.
+    rewriter.updateRootInPlace(op,
+                               [&] { op->setOperands(remainingOperands); });
+    return success();
+  }
+};
+
+/// Simplify trivial gpu.wait ops. The following forms are handled:
+///  1. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
+///     dependencies nor return a token, are erased.
+///  2. %t1 = gpu.wait async [%t0]: all uses of %t1 are replaced with %t0 and
+///     the op is erased.
+///  3. %t = gpu.wait async ... ops whose token %t has no uses are erased,
+///     regardless of their async dependencies.
+struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WaitOp op,
+                                PatternRewriter &rewriter) const final {
+    Value token = op.asyncToken();
+    auto dependencies = op.asyncDependencies();
+
+    // Case 1: without a result token, only a dependency-free wait is erased.
+    if (!token) {
+      if (!dependencies.empty())
+        return failure();
+      rewriter.eraseOp(op);
+      return success();
+    }
+    // Case 2: %t1 = gpu.wait async [%t0] is replaced by its single dependency.
+    if (llvm::hasSingleElement(dependencies)) {
+      rewriter.replaceOp(op, dependencies);
+      return success();
+    }
+    // Case 3: the produced token is never used.
+    if (token.use_empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    return failure();
+  }
+};
+
+} // end anonymous namespace
+
+/// Registers the gpu.wait canonicalization patterns: pruning of redundant
+/// async dependencies, followed by simplification of trivial wait ops.
+void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<EraseRedundantGpuWaitOpPairs>(context);
+  results.add<SimplifyGpuWaitOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GPU_AllocOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
index 72f67b901d005..225246be49b80 100644
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -1,5 +1,33 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
 
+// Fold all the gpu.wait ops as they are redundant.
+// The unused `%1 = gpu.wait async` and the dependency-free `gpu.wait []` are
+// erased directly. `gpu.wait [%3]` first loses its dependency on the
+// dependency-free `%3`; after that, both it and the producer of `%3` are erased.
+// CHECK-LABEL: func @fold_wait_op_test1
+func @fold_wait_op_test1() {
+  %1 = gpu.wait async
+  gpu.wait []
+  %3 = gpu.wait async
+  gpu.wait [%3]
+  return
+}
+// CHECK-NOT: gpu.wait
+
+// Prune dependencies on tokens produced by dependency-free gpu.wait ops.
+// The expected output keeps a separate `gpu.wait async` between the two
+// allocs. That means `%1 = gpu.wait async [%0]` is not folded into `%0`;
+// instead its dependency on the dependency-free `%0` is dropped. The
+// `gpu.wait [%0]` and `gpu.wait [%1]` ops are erased.
+// NOTE(review): pruning rather than replacing `%1` relies on the order in
+// which the two gpu.wait canonicalization patterns are tried. The
+// single-dependency replacement pattern is not exercised by this function.
+// CHECK-LABEL: func @fold_wait_op_test2
+func @fold_wait_op_test2(%arg0: i1) -> (memref<5xf16>, memref<5xf16>) {
+  %0 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<5xf16>
+  gpu.wait [%0]
+  %1 = gpu.wait async [%0]
+  %memref_0, %asyncToken_0 = gpu.alloc async [%1] () : memref<5xf16>
+  gpu.wait [%1]
+  return %memref, %memref_0 : memref<5xf16>, memref<5xf16>
+}
+// CHECK-NEXT: %[[TOKEN0:.*]] = gpu.wait async
+// CHECK-NEXT: gpu.alloc async [%[[TOKEN0]]] ()
+// CHECK-NEXT: %[[TOKEN1:.*]] = gpu.wait async
+// CHECK-NEXT: gpu.alloc async [%[[TOKEN1]]] ()
+// CHECK-NEXT: return
+
 // CHECK-LABEL: @memcpy_after_cast
 func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
   // CHECK-NOT: memref.cast


        


More information about the Mlir-commits mailing list