[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 21 10:11:53 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector
Author: Hugo Trachino (nujaa)
<details>
<summary>Changes</summary>
This PR is part of a series of PRs aiming to generalize `DropUnitDimFromElementwiseOps` to other ops.
This commit implements `DropUnitDimFromBroadcastOp` to target `vector::BroadcastOp`.
Discussed [here](https://discourse.llvm.org/t/mlir-for-arm-sme-vectorizing-matmul-like-ops-as-part-of-a-broader-program/78603).
---
Full diff: https://github.com/llvm/llvm-project/pull/92938.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+62-2)
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+55)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..a8494eac3e5aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1695,6 +1695,66 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Drops non-scalable unit dimensions that are shared by the source and the
+/// result of a vector.broadcast, using shape_casts.
+/// The newly inserted shape_cast ops drop the unit dims before the broadcast
+/// and restore them after it. The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+/// vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+/// vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVT)
+ return failure();
+ auto resVT = broadcastOp.getResultVectorType();
+ VectorType newSrcVT = srcVT;
+ VectorType newResVT = resVT;
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+ // Reversing allows us to remove dims from the back without keeping track of
+ // removed dimensions.
+ for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+ if (dim.value() == 1 &&
+ !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
+ !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
+ newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
+ dim.index() - 1);
+ newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
+ dim.index() - 1);
+ }
+ }
+
+ if (newSrcVT == srcVT)
+ return failure();
+ auto loc = broadcastOp->getLoc();
+ auto newSource = rewriter.create<vector::ShapeCastOp>(
+ loc, newSrcVT, broadcastOp.getSource());
+ auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+ newOp.getResult());
+ return success();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1819,8 +1879,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+ ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed..f1fc443b9d4bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -460,6 +460,61 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-NOT: memref.collapse_shape
+// -----
+
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+ %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+ return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+ %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+ return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Unit dimensions generated by a broadcast are not dropped, as we prefer to have a
+// single broadcast rather than a broadcast and a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+ return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A broadcasted unit dimension cannot be dropped, as doing so would cause a type mismatch.
+func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+ return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>
+
+
// -----
func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/92938
More information about the Mlir-commits
mailing list