[Mlir-commits] [mlir] 0cf7aaf - [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (#112394)

llvmlistbot at llvm.org
Sat Oct 26 05:54:08 PDT 2024


Author: Andrzej WarzyƄski
Date: 2024-10-26T13:54:04+01:00
New Revision: 0cf7aaf30067c4be2886a8c9127a27dcbfd63b92

URL: https://github.com/llvm/llvm-project/commit/0cf7aaf30067c4be2886a8c9127a27dcbfd63b92
DIFF: https://github.com/llvm/llvm-project/commit/0cf7aaf30067c4be2886a8c9127a27dcbfd63b92.diff

LOG: [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (#112394)

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
`vector.transfer_read`/`vector.transfer_write` ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
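
After the rewrite, the unit dims are dropped through a `memref.subview` and
the masked transfer reads from the reduced-rank view instead. The following is
a sketch of the resulting IR, based on the test expectations further down (the
strided layout of the subview result is illustrative):

```mlir
    %subview = memref.subview %arg[0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] :
      memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> to
      memref<3x2xi8, strided<[2, 1], offset: ?>>
    %v = vector.mask %mask {
      vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} :
        memref<3x2xi8, strided<[2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```

The mask itself is shape-cast to the reduced rank with `createOrFold`, so when
its type already matches the reduced vector type (as here), no cast is
materialized.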

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..3a30382114c8dc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
 /// inserting a memref.subview dropping those unit dims. The vector shapes are
 /// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -376,6 +378,10 @@ class TransferReadDropUnitDimsPattern
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
       return failure();
+    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
+    // out.
+    if (reducedRank == 0 && maskingOp)
+      return failure();
     // Check if the reduced vector shape matches the reduced source shape.
     // Otherwise, this case is not supported yet.
     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -406,15 +412,23 @@ class TransferReadDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
         rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newTransferReadOp = mlir::vector::maskOperation(
+          rewriter, newTransferReadOp, shapeCastMask);
+    }
+
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, vectorType, newTransferReadOp);
-    rewriter.replaceOp(transferReadOp, shapeCast);
+        loc, vectorType, newTransferReadOp->getResults()[0]);
 
-    return success();
+    return shapeCast;
   }
 };
 
@@ -422,11 +436,13 @@ class TransferReadDropUnitDimsPattern
 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
 /// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -444,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
       return failure();
+    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
+    // out.
+    if (reducedRank == 0 && maskingOp)
+      return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
     VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -474,13 +494,26 @@ class TransferWriteDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+        maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newXferWrite =
+          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+    }
 
-    return success();
+    if (transferWriteOp.hasPureTensorSemantics())
+      return newXferWrite->getResults()[0];
+
+    // With memref semantics, there's no return value. Use an empty value to
+    // signal success.
+    return Value();
   }
 };
 

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5b0fb537b35655..56039d04549aa5 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->
 
 // -----
 
+func.func @vector_mask_passthru_type_mismatch(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
+  %ft0 = arith.constant 0.0 : f32
+  // expected-error at +1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+  %0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
+  return %0 : vector<f32>
+}
+
+// -----
+
 // expected-note at +1 {{prior use here}}
 func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
   %ft0 = arith.constant 0.0 : f32

diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..8234351302f6b5 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+    } : vector<3x2xi1> -> vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:  vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>) {
+
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+func.func @transfer_write_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>,
+      %mask: vector<3x2xi1>) {
+    %c0 = arith.constant 0 : index
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+        vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+    } : vector<3x2xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[VEC:.+]]: vector<3x2xi8>
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index
@@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
 //       CHECK:   vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
 
+func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
+      %arg : memref<1x1x1x1x1xf32>,
+      %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
+
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
+        : memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
+    } : vector<1x1x1xi1> -> vector<1x1x1xf32>
+    return %v : vector<1x1x1xf32>
+}
+// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
+//   CHECK-NOT:   vector.shape_cast
+//   CHECK-NOT:   memref.subview
+
 func.func @transfer_write_and_vector_rank_reducing_to_0d(
       %arg : memref<1x1x1x1x1xf32>,
       %vec : vector<1x1x1xf32>) {
@@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
 //       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
+func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
+      %arg : memref<1x1x1x1x1xf32>,
+      %vec : vector<1x1x1xf32>,
+      %mask: vector<1x1x1xi1>) {
+
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
+        vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
+    } : vector<1x1x1xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
+//   CHECK-NOT:   vector.shape_cast
+//   CHECK-NOT:   memref.subview
+
 func.func @transfer_read_dynamic_rank_reducing(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
     %c0 = arith.constant 0 : index


        

