[Mlir-commits] [mlir] 1ee1143 - [mlir][Linalg][Vector] Add forwarding patterns between linalg.copy and vector.transfer

Nicolas Vasilache llvmlistbot at llvm.org
Fri May 29 05:10:25 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-29T08:08:34-04:00
New Revision: 1ee114322cb251f851028c72e7974bf85e707e55

URL: https://github.com/llvm/llvm-project/commit/1ee114322cb251f851028c72e7974bf85e707e55
DIFF: https://github.com/llvm/llvm-project/commit/1ee114322cb251f851028c72e7974bf85e707e55.diff

LOG: [mlir][Linalg][Vector] Add forwarding patterns between linalg.copy and vector.transfer

This revision adds custom rewrites for patterns that arise during vectorization of
linalg structured ops. These patterns allow composing linalg promotion,
vectorization, and the removal of redundant copies.

The patterns are deliberately limited and restrictive for now.
More robust behavior will be implemented once more powerful side-effect modeling and analyses are available on view/subview.

On the transfer_read side, the following pattern is rewritten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView ...
   [optional] linalg.fill(%allocOrView, %cst) ...
   ...
   linalg.copy(%in, %subView) ...
   vector.transfer_read %allocOrView[...], %cst ...
```
into
```
   [unchanged] %alloc = ...
   [unchanged] [optional] %view = std.view %alloc ...
   [unchanged] %subView = subview %allocOrView ...
   ...
   vector.transfer_read %in[...], %cst ...
```

On the transfer_write side, the following pattern is rewritten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView...
   ...
   vector.transfer_write %..., %allocOrView[...]
   linalg.copy(%subView, %out)
```
into
```
   [unchanged] %alloc = ...
   [unchanged] [optional] %view = std.view %alloc ...
   [unchanged] %subView = subview %allocOrView...
   ...
   vector.transfer_write %..., %out[...]
```

Differential Revision: https://reviews.llvm.org/D80728
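
For downstream users, wiring the two patterns into a pass mirrors the test
hook added in TestLinalgTransforms.cpp below. Here is a minimal sketch, not
part of this commit: the helper name `forwardVectorTransfers` is illustrative,
and the includes follow what the test file uses.
```
// Greedily apply the two forwarding patterns on a function, as the test
// pass below does.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void forwardVectorTransfers(FuncOp funcOp) {
  OwningRewritePatternList patterns;
  // Fold linalg.copy (and an optional linalg.fill) into vector.transfer_read.
  patterns.insert<linalg::LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  // Redirect vector.transfer_write through linalg.copy to the final buffer.
  patterns.insert<linalg::LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  applyPatternsAndFoldGreedily(funcOp, patterns);
}
```
Both patterns bail out conservatively whenever mayExistInterleavedUses detects
a potentially conflicting use, so greedy application is safe.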

Added: 
    mlir/test/Dialect/Linalg/forward-vector-transfers.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2da631956572..2e0673795f30 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -14,6 +14,13 @@
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
+namespace vector {
+
+class TransferReadOp;
+class TransferWriteOp;
+
+} // namespace vector
+
 namespace linalg {
 
 struct LinalgTilingOptions;
@@ -437,6 +444,67 @@ struct LinalgLoweringPattern : public RewritePattern {
   LinalgLoweringType loweringType;
 };
 
+//===----------------------------------------------------------------------===//
+// Op-specific patterns.
+//===----------------------------------------------------------------------===//
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView ...
+///    [optional] linalg.fill(%allocOrView, %cst) ...
+///    ...
+///    linalg.copy(%in, %subView) ...
+///    vector.transfer_read %allocOrView[...], %cst ...
+/// ```
+/// into
+/// ```
+///    [unchanged] %alloc = ...
+///    [unchanged] [optional] %view = std.view %alloc ...
+///    [unchanged] %subView = subview %allocOrView ...
+///    ...
+///    vector.transfer_read %in[...], %cst ...
+/// ```
+/// Where there is no interleaved use between linalg.copy and transfer_read as
+/// well as no interleaved use between linalg.fill and linalg.copy (if
+/// linalg.fill is specified).
+/// This is a custom rewrite to forward partial reads (with optional fills) to
+/// vector.transfer_read.
+struct LinalgCopyVTRForwardingPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView...
+///    ...
+///    vector.transfer_write %..., %allocOrView[...]
+///    linalg.copy(%subView, %out)
+/// ```
+/// into
+/// ```
+///    [unchanged] %alloc = ...
+///    [unchanged] [optional] %view = std.view %alloc ...
+///    [unchanged] %subView = subview %allocOrView...
+///    ...
+///    vector.transfer_write %..., %out[...]
+/// ```
+/// Where there is no interleaved use between transfer_write and linalg.copy.
+/// This is a custom rewrite to forward partial writes to vector.transfer_write.
+struct LinalgCopyVTWForwardingPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f27baa3c662a..8fa0aa35a874 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -103,12 +103,13 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
       llvm_unreachable("Unexpected conv with padding");
   }
 
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
     // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: Rewrite linalg.fill as vector.broadcast: "
-                      << *op << ":\n");
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
     Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
     Value dst = std_load(memref);
     Value res = vector_broadcast(dst.getType(), fillOp.value());
@@ -117,9 +118,8 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   }
 
   // Vectorize other ops as vector contraction (currently only matmul).
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
   Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
@@ -129,3 +129,168 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                               linalgOp.iterator_types());
   std_store(res, memref);
 }
+
+/// Check whether there is any interleaved use of any `values` between `firstOp`
+/// and `secondOp`. Conservatively return `true` if any op or value is in a
+/// different block.
+static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
+                                    ValueRange values) {
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
+  if (firstOp->getBlock() != secondOp->getBlock() ||
+      !firstOp->isBeforeInBlock(secondOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << dbgPref << "interleavedUses precondition failed, firstOp: "
+               << *firstOp << ", second op: " << *secondOp);
+    return true;
+  }
+  for (auto v : values) {
+    for (auto &u : v.getUses()) {
+      Operation *owner = u.getOwner();
+      if (owner == firstOp || owner == secondOp)
+        continue;
+      // TODO: this is too conservative, use dominance info in the future.
+      if (owner->getBlock() == firstOp->getBlock() &&
+          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
+        continue;
+      LLVM_DEBUG(llvm::dbgs()
+                 << dbgPref << " found interleaved op " << *owner
+                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+static SubViewOp getSubViewUseIfUnique(Value v) {
+  SubViewOp subViewOp;
+  for (auto &u : v.getUses()) {
+    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
+      if (subViewOp)
+        return SubViewOp();
+      subViewOp = newSubViewOp;
+    }
+  }
+  return subViewOp;
+}
+
+/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
+/// when available.
+LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
+    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
+
+  // Transfer into `view`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
+  (void)dbgPref;
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
+
+  // Find the copy into `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subView.getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getOutputBuffer(0) != subView)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
+      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
+
+  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+  FillOp maybeFillOp;
+  for (auto &u : viewOrAlloc.getUses()) {
+    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
+      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
+      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
+        continue;
+      maybeFillOp = newFillOp;
+      break;
+    }
+  }
+  // Ensure padding matches.
+  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
+    return failure();
+  if (maybeFillOp)
+    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
+
+  // `in` is the buffer that linalg.copy reads from; read it directly instead.
+  Value in = copyOp.getInput(0);
+
+  Value res = rewriter.create<vector::TransferReadOp>(
+      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
+      xferOp.permutation_map(), xferOp.padding(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  if (maybeFillOp)
+    rewriter.eraseOp(maybeFillOp);
+  rewriter.eraseOp(copyOp);
+  rewriter.replaceOp(xferOp, res);
+
+  return success();
+}
+
+/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
+/// when available.
+LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
+    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
+  // Transfer into `viewOrAlloc`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+
+  // Find the copy from `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subViewOp.getResult().getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getInput(0) != subView)
+        continue;
+      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+
+  // `out` is the buffer that `subView` is copied into; forward the write to it.
+  Value out = copyOp.getOutputBuffer(0);
+
+  // Forward vector.transfer into copy.
+  rewriter.create<vector::TransferWriteOp>(
+      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
+      xferOp.permutation_map(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  rewriter.eraseOp(copyOp);
+  rewriter.eraseOp(xferOp);
+
+  return success();
+}

diff --git a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
new file mode 100644
index 000000000000..7f56234219fe
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-vector-transfer-forwarding-patterns | FileCheck %s
+
+// CHECK-LABEL: testAllocRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.fill(%view, %f0): memref<32 x f32>, f32
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testAllocWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}
+
+// CHECK-LABEL: testViewWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testViewWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %view[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<128 x i8>
+  return
+}
+
+///===--------------------------------------------------------------------===///
+// Negative tests
+///===--------------------------------------------------------------------===///
+
+// This should fail the rewrite: mismatched fill / transfer_read values, plus an interleaved use.
+// CHECK-LABEL: failAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_read %[[ARG0]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: linalg.copy
+//       CHECK: vector.transfer_read %[[ALLOC]]
+func @failAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %f1 = constant 1.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  %0 = vector.transfer_read %alloc[%c0], %f1: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// This should fail the rewrite due to an interleaved use of the subview.
+// CHECK-LABEL: failAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_write %[[ARG0]], %[[ARG1]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ALLOC]]
+//       CHECK: linalg.copy
+func @failAllocWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}

diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index c38494fe2778..31189f47f9ae 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -48,6 +49,11 @@ struct TestLinalgTransforms
   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                     llvm::cl::desc("Test promotion options"),
                                     llvm::cl::init(false)};
+  Option<bool> testVectorTransferForwardingPatterns{
+      *this, "test-vector-transfer-forwarding-patterns",
+      llvm::cl::desc(
+          "Test a fused pass that forwards linalg.copy to vector.transfer"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -167,19 +173,6 @@ static void applyPatterns(FuncOp funcOp) {
   });
 }
 
-static OwningRewritePatternList
-getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  AffineMinOp::getCanonicalizationPatterns(patterns, context);
-  AffineMaxOp::getCanonicalizationPatterns(patterns, context);
-  AllocOp::getCanonicalizationPatterns(patterns, context);
-  SubViewOp::getCanonicalizationPatterns(patterns, context);
-  ViewOp::getCanonicalizationPatterns(patterns, context);
-  MatmulOp::getCanonicalizationPatterns(patterns, context);
-  return patterns;
-}
-
 static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
@@ -261,40 +254,58 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
       LinalgMarker({"PROMOTE"}));
 }
 
+static void
+applyMatmulToVectorPatterns(FuncOp funcOp,
+                            bool testMatmulToVectorPatterns1dTiling,
+                            bool testMatmulToVectorPatterns2dTiling) {
+  MLIRContext *ctx = funcOp.getContext();
+  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+  if (testMatmulToVectorPatterns1dTiling) {
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
+  } else if (testMatmulToVectorPatterns2dTiling) {
+    stage1Patterns.emplace_back(
+        LinalgTilingPattern<MatmulOp>(ctx,
+                                      LinalgTilingOptions()
+                                          .setTileSizes({768, 264, 768})
+                                          .setInterchange({1, 2, 0}),
+                                      LinalgMarker({"START"}, "L2")));
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
+  }
+  OwningRewritePatternList stage2Patterns =
+      getLinalgTilingCanonicalizationPatterns(ctx);
+  applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns);
+}
+
+static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
+  OwningRewritePatternList forwardPattern;
+  forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
+  forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, forwardPattern);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
-  if (testPatterns) {
-    applyPatterns(getFunction());
-    return;
-  }
+  auto lambda = [&](void *) {
+    getFunction().walk([](LinalgOp op) {
+      op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+    });
+  };
+  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
+
   if (testPromotionOptions) {
     OwningRewritePatternList patterns;
     fillPromotionCallBackPatterns(&getContext(), patterns);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
-  } else {
-    SmallVector<OwningRewritePatternList, 4> stage1Patterns;
-    if (testMatmulToVectorPatterns1dTiling) {
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "START",
-                                            stage1Patterns);
-    } else if (testMatmulToVectorPatterns2dTiling) {
-      stage1Patterns.emplace_back(
-          LinalgTilingPattern<MatmulOp>(&getContext(),
-                                        LinalgTilingOptions()
-                                            .setTileSizes({768, 264, 768})
-                                            .setInterchange({1, 2, 0}),
-                                        LinalgMarker({"START"}, "L2")));
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "L2",
-                                            stage1Patterns);
-    }
-    OwningRewritePatternList stage2Patterns =
-        getMatmulToVectorCanonicalizationPatterns(&getContext());
-    applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
+    return;
   }
-
-  // Drop the marker.
-  getFunction().walk([](LinalgOp op) {
-    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
-  });
+  if (testPatterns)
+    return applyPatterns(getFunction());
+  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
+    return applyMatmulToVectorPatterns(getFunction(),
+                                       testMatmulToVectorPatterns1dTiling,
+                                       testMatmulToVectorPatterns2dTiling);
+  if (testVectorTransferForwardingPatterns)
+    return applyVectorTransferForwardingPatterns(getFunction());
 }
 
 namespace mlir {
