[Mlir-commits] [mlir] [mlir][vector] Add leading unit dim folding patterns for masked transfers (PR #71466)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 6 16:19:31 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
This handles `vector.transfer_read`, `vector.transfer_write`, and `vector.constant_mask`. The unit dims are only relevant for masks created by `create_mask` and `constant_mask` if the mask size for the unit dim is zero (the only non-one size a unit dim can have), in which case all subsequent mask sizes must also be zero. From the perspective of the vector transfers, however, these unit dims can just be dropped directly.
---
Full diff: https://github.com/llvm/llvm-project/pull/71466.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+59-10)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+35)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6bbb293fa2a6b5c..75f32b23e57b0d6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include <numeric>
+
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -208,9 +210,6 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getTransferRank() == 0)
return failure();
- if (read.getMask())
- return failure();
-
auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -233,10 +232,18 @@ struct CastAwayTransferReadLeadingOneDim
inBoundsAttr = rewriter.getArrayAttr(
read.getInBoundsAttr().getValue().take_back(newType.getRank()));
+ Value mask = Value();
+ if (read.getMask()) {
+    // The mask shape must always match the shape of the vector being read, so
+    // we can safely use the same extraction indices.
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
+ splatZero(dropDim));
+ }
+
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), newType, read.getSource(), read.getIndices(),
- AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(),
- inBoundsAttr);
+ AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
return success();
@@ -256,9 +263,6 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getTransferRank() == 0)
return failure();
- if (write.getMask())
- return failure();
-
auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -283,10 +287,21 @@ struct CastAwayTransferWriteLeadingOneDim
auto newVector = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getVector(), splatZero(dropDim));
+
+ if (write.getMask()) {
+ // The mask shape must always match the shape of the written vector, so we
+ // can safely use the same extraction indices.
+ auto newMask = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.getMask(), splatZero(dropDim));
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ write, newVector, write.getSource(), write.getIndices(),
+ AffineMapAttr::get(newMap), newMask, inBoundsAttr);
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), inBoundsAttr);
-
return success();
}
};
@@ -467,6 +482,40 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
}
};
+// Drops leading 1 dimensions from vector.constant_mask and inserts a
+// vector.broadcast back to the original shape.
+struct CastAwayConstantMaskLeadingOneDim
+ : public OpRewritePattern<vector::ConstantMaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
+ PatternRewriter &rewriter) const override {
+ VectorType oldType = mask.getType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ SmallVector<int64_t> dimSizes;
+ for (auto attr : mask.getMaskDimSizes())
+ dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+
+ // If any of the dropped unit dims has a size of `0`, the entire mask is a
+ // zero mask, else the unit dim has no effect on the mask.
+ int64_t flatLeadingSize =
+ std::accumulate(dimSizes.begin(), dimSizes.begin() + dropDim + 1,
+ static_cast<int64_t>(1), std::multiplies<int64_t>());
+ SmallVector<int64_t> newDimSizes({flatLeadingSize});
+ newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
+
+ auto newMask = rewriter.create<vector::ConstantMaskOp>(
+ mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
@@ -474,7 +523,7 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns
.add<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
+ CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
populateShapeCastFoldingPatterns(patterns, benefit);
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index e5b27b04dcc8096..5de30206927db2f 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -209,6 +209,20 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>)
return %0: vector<1x4xf16>
}
+// CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
+func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
+ %f0 = arith.constant 0. : f16
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
+ // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+ %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+ // CHECK: return %[[CAST]]
+ return %0: vector<1x4xf16>
+}
+
// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
%c0 = arith.constant 0 : index
@@ -229,6 +243,18 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>
return
}
+// CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
+func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
+ // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
+ // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+ return
+}
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
%c0 = arith.constant 0 : index
@@ -410,3 +436,12 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x
%0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
return %0: vector<1x1x8x1x[8]xi1>
}
+
+// CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+// CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
+// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1>
+func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
+ %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
+ return %0: vector<1x1x8x2x1xi1>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/71466
More information about the Mlir-commits
mailing list