[Mlir-commits] [mlir] [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (PR #112394)

Andrzej WarzyƄski llvmlistbot at llvm.org
Sat Oct 19 02:19:18 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/112394

>From 0de9b8c19c716c92c16fdeacd4ef7fac44087d55 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 15 Oct 2024 17:04:56 +0100
Subject: [PATCH 1/2] [MLIR][Vector] Update
 Transfer{Read|Write}DropUnitDimsPattern patterns

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
`vector.transfer_read`/`vector.transfer_write` ops are also supported, e.g.:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
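For reference, the expected output for the example above is sketched below. This is
a sketch based on the CHECK lines in the updated test; the exact subview layout and
SSA names are illustrative rather than verbatim compiler output:

```mlir
    %subview = memref.subview %arg[0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] :
      memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> to
      memref<3x2xi8, strided<[2, 1], offset: ?>>
    %v = vector.mask %mask {
      vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} :
        memref<3x2xi8, strided<[2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```

Since the mask already has the reduced shape, the `vector.shape_cast` created for it
folds away and the original mask is reused directly.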
---
 .../Transforms/VectorTransferOpTransforms.cpp | 59 +++++++++++++------
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 48 ++++++++++++++-
 2 files changed, 89 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
 /// inserting a memref.subview dropping those unit dims. The vector shapes are
 /// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
         rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newTransferReadOp = mlir::vector::maskOperation(
+          rewriter, newTransferReadOp, shapeCastMask);
+    }
+
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, vectorType, newTransferReadOp);
-    rewriter.replaceOp(transferReadOp, shapeCast);
+        loc, vectorType, newTransferReadOp->getResults()[0]);
 
-    return success();
+    return shapeCast;
   }
 };
 
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
 /// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+        maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newXferWrite =
+          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+    }
 
-    return success();
+    if (transferWriteOp.hasPureTensorSemantics())
+      return newXferWrite->getResults()[0];
+
+    // With memref semantics, there's no return value; use an empty Value to
+    // signal success.
+    return Value();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+    } : vector<3x2xi1> -> vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:  vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>) {
+
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+func.func @transfer_write_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>,
+      %mask: vector<3x2xi1>) {
+    %c0 = arith.constant 0 : index
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+        vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+    } : vector<3x2xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[VEC:.+]]: vector<3x2xi8>
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index

>From 84c4228c27ab58eb5d2a99a6fa02040eee8d8448 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 19 Oct 2024 10:18:39 +0100
Subject: [PATCH 2/2] fixup! [MLIR][Vector] Update
 Transfer{Read|Write}DropUnitDimsPattern patterns

Bail out for 0-D cases
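For example, the reduced form of the masked read below would be 0-D, which
`vector.mask` cannot currently wrap, so the pattern now leaves such ops untouched
(this matches the shape of the new `*_to_0d_masked` tests added below):

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
        memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
    } : vector<1x1x1xi1> -> vector<1x1x1xf32>
```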
---
 .../Transforms/VectorTransferOpTransforms.cpp | 12 +++++--
 mlir/test/Dialect/Vector/invalid.mlir         |  9 +++++
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 33 +++++++++++++++++++
 3 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 6cdfa645d78f9b..3a30382114c8dc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -378,6 +378,10 @@ class TransferReadDropUnitDimsPattern
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
       return failure();
+    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
+    // out.
+    if (reducedRank == 0 && maskingOp)
+      return failure();
     // Check if the reduced vector shape matches the reduced source shape.
     // Otherwise, this case is not supported yet.
     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -415,7 +419,7 @@ class TransferReadDropUnitDimsPattern
 
     if (maskingOp) {
       auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
-          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
           maskingOp.getMask());
       newTransferReadOp = mlir::vector::maskOperation(
           rewriter, newTransferReadOp, shapeCastMask);
@@ -456,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
       return failure();
+    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
+    // out.
+    if (reducedRank == 0 && maskingOp)
+      return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -494,7 +502,7 @@ class TransferWriteDropUnitDimsPattern
 
     if (maskingOp) {
       auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
-          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
           maskingOp.getMask());
       newXferWrite =
           mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 36d04bb77e3b96..bebe47ba2db9a8 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->
 
 // -----
 
+func.func @vector_mask_0d_mask(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+  %0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
+  return %0 : vector<f32>
+}
+
+// -----
+
 // expected-note@+1 {{prior use here}}
 func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
   %ft0 = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 9b61c8ea76f962..8234351302f6b5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -114,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
 //       CHECK:   vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
 
+func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
+      %arg : memref<1x1x1x1x1xf32>,
+      %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
+
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
+        : memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
+    } : vector<1x1x1xi1> -> vector<1x1x1xf32>
+    return %v : vector<1x1x1xf32>
+}
+// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
+//   CHECK-NOT:   vector.shape_cast
+//   CHECK-NOT:   memref.subview
+
 func.func @transfer_write_and_vector_rank_reducing_to_0d(
       %arg : memref<1x1x1x1x1xf32>,
       %vec : vector<1x1x1xf32>) {
@@ -128,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
 //       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
+func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
+      %arg : memref<1x1x1x1x1xf32>,
+      %vec : vector<1x1x1xf32>,
+      %mask: vector<1x1x1xi1>) {
+
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
+        vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
+    } : vector<1x1x1xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
+//   CHECK-NOT:   vector.shape_cast
+//   CHECK-NOT:   memref.subview
+
 func.func @transfer_read_dynamic_rank_reducing(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
     %c0 = arith.constant 0 : index


