[Mlir-commits] [mlir] [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (PR #112394)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Oct 15 09:25:09 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/112394

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```


From c6a91f5e8b458d2316a7464db354774897b685df Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 15 Oct 2024 17:04:56 +0100
Subject: [PATCH] [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern
 patterns

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
---
 .../Transforms/VectorTransferOpTransforms.cpp | 59 +++++++++++++------
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 48 ++++++++++++++-
 2 files changed, 89 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
 /// inserting a memref.subview dropping those unit dims. The vector shapes are
 /// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
         rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newTransferReadOp = mlir::vector::maskOperation(
+          rewriter, newTransferReadOp, shapeCastMask);
+    }
+
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, vectorType, newTransferReadOp);
-    rewriter.replaceOp(transferReadOp, shapeCast);
+        loc, vectorType, newTransferReadOp->getResults()[0]);
 
-    return success();
+    return shapeCast;
   }
 };
 
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
 /// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+        maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newXferWrite =
+          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+    }
 
-    return success();
+    if (transferWriteOp.hasPureTensorSemantics())
+      return newXferWrite->getResults()[0];
+
+    // With Memref semantics, there's no return value. Use empty value to signal
+    // success.
+    return Value();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+    } : vector<3x2xi1> -> vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:  vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>) {
+
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+func.func @transfer_write_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>,
+      %mask: vector<3x2xi1>) {
+    %c0 = arith.constant 0 : index
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+        vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+    } : vector<3x2xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[VEC:.+]]: vector<3x2xi8>
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index



More information about the Mlir-commits mailing list