[Mlir-commits] [mlir] [mlir][vector] Fix unit dim dropping pattern for masked writes (PR #74038)

Quinn Dawkins llvmlistbot at llvm.org
Thu Nov 30 21:22:05 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/74038

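This applies the same fix as #72142 to vector.transfer_write. Previously, the unit-dim dropping pattern silently dropped the mask of a masked write; it now rank-reduces masks produced by `vector.create_mask` and fails to match on any other kind of mask.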

>From 54fdb4e9ec5e8c2e5927c3f290137180c64e69fc Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 30 Nov 2023 23:46:27 -0500
Subject: [PATCH] [mlir][vector] Fix unit dim dropping pattern for masked
 writes

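This applies the same fix as #72142 to vector.transfer_write.
Previously the pattern silently dropped the mask of a masked write; it
now rank-reduces masks defined by vector.create_mask and fails to match
on any other kind of mask.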
---
 .../Transforms/VectorTransferOpTransforms.cpp | 38 +++++++++-------
 ...ctor-transfer-drop-unit-dims-patterns.mlir | 44 +++++++++++++++++++
 2 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
-}
-
 /// Converts OpFoldResults to int64_t shape without unit dims.
 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
   SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    Value maskOp = transferWriteOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            transferWriteOp,
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+      FailureOr<Value> rankReducedCreateMask =
+          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      if (failed(rankReducedCreateMask))
+        return failure();
+      maskOp = *rankReducedCreateMask;
+    }
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    VectorType reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
+func.func @masked_transfer_write_and_vector_rank_reducing(
+      %arg : memref<1x1x3x1x16x1xf32>,
+      %vec : vector<1x3x1x16x1xf32>,
+      %mask_dim1 : index,
+      %mask_dim2 : index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
+//  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
+//  CHECK-SAME:     %[[MASKDIM2:.+]]: index
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %vec : vector<[16]x1xi8>,
+      %mask_dim0 : index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+      vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
+//  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
 /// Only masks operands of vector.create_mask are currently supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,



More information about the Mlir-commits mailing list