[Mlir-commits] [mlir] [mlir][vector] Add leading unit dim folding patterns for masked transfers (PR #71466)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 6 16:19:31 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims only matter for masks created by `create_mask` and `constant_mask` when the mask size for a unit dim is not one (i.e. it is zero), in which case all subsequent sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can simply be dropped.

---
Full diff: https://github.com/llvm/llvm-project/pull/71466.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+59-10) 
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+35) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6bbb293fa2a6b5c..75f32b23e57b0d6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
     if (read.getTransferRank() == 0)
       return failure();
 
-    if (read.getMask())
-      return failure();
-
     auto shapedType = cast<ShapedType>(read.getSource().getType());
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
       inBoundsAttr = rewriter.getArrayAttr(
           read.getInBoundsAttr().getValue().take_back(newType.getRank()));
 
+    Value mask = Value();
+    if (read.getMask()) {
+      // The mask shape must always match the shape of the read vector, so we
+      // can safely use the same extraction indices.
+      int64_t dropDim = oldType.getRank() - newType.getRank();
+      mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
+                                                splatZero(dropDim));
+    }
+
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), newType, read.getSource(), read.getIndices(),
-        AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
-        inBoundsAttr);
+        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
     if (write.getTransferRank() == 0)
       return failure();
 
-    if (write.getMask())
-      return failure();
-
     auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
 
     auto newVector = rewriter.create<vector::ExtractOp>(
         write.getLoc(), write.getVector(), splatZero(dropDim));
+
+    if (write.getMask()) {
+      // The mask shape must always match the shape of the written vector, so we
+      // can safely use the same extraction indices.
+      auto newMask = rewriter.create<vector::ExtractOp>(
+          write.getLoc(), write.getMask(), splatZero(dropDim));
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+          write, newVector, write.getSource(), write.getIndices(),
+          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         write, newVector, write.getSource(), write.getIndices(),
         AffineMapAttr::get(newMap), inBoundsAttr);
-
     return success();
   }
 };
@@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
   }
 };
 
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
+struct CastAwayConstantMaskLeadingOneDim
+    : public OpRewritePattern<vector::ConstantMaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldType = mask.getType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    int64_t dropDim = oldType.getRank() - newType.getRank();
+    SmallVector<int64_t> dimSizes;
+    for (auto attr : mask.getMaskDimSizes())
+      dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+
+    // If any of the dropped unit dims has a size of `0`, the entire mask is a
+    // zero mask, else the unit dim has no effect on the mask.
+    int64_t flatLeadingSize =
+        std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
+                        static_cast<int64_t>(1), std::multiplies<int64_t>());
+    SmallVector<int64_t> newDimSizes({flatLeadingSize});
+    newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
+
+    auto newMask = rewriter.create<vector::ConstantMaskOp>(
+        mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
@@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
   patterns
       .add<CastAwayExtractStridedSliceLeadingOneDim,
            CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
-           CastAwayTransferReadLeadingOneDim,
+           CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index e5b27b04dcc8096..5de30206927db2f 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
   return %0: vector<1x4xf16>
 }
 
+// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
+func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+  %f0 = arith.constant 0. : f16
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x4xf16>
+}
+
 // CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
 func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
   %c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
   return
 }
 
+// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
+func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+  // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+}
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
 func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
   %c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
   return %0: vector<1x1x8x1x[8]xi1>
 }
+
+// CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK:           return %[[BCAST]] : vector<1x1x8x2x1xi1>
+func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+  %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
+  return %0: vector<1x1x8x2x1xi1>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/71466


More information about the Mlir-commits mailing list