[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)
Hugo Trachino
llvmlistbot at llvm.org
Thu Jun 20 07:31:23 PDT 2024
https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/92938
>From d94eb3d81c6e53b90c022ee81c92ac5fac9be79e Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 16 May 2024 19:08:32 +0800
Subject: [PATCH 1/3] [MLIR][Vector] Enable DropUnitDimFromBroadcastOp
---
.../Vector/Transforms/VectorTransforms.cpp | 64 ++++++++++++++++++-
.../Vector/vector-transfer-flatten.mlir | 54 ++++++++++++++++
2 files changed, 116 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 200517913677f..5d8b77621e0d7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Drops non-scalable unit dimensions of a broadcastOp's source that are also
+/// present in its result (i.e. that are not broadcasted), using shape_casts.
+/// A shape_cast inserted before the Op removes the unit dims and another
+/// one restores them after the Op. Source type is required to be a vector.
+///
+/// Ex:
+/// ```
+/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+/// vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+/// vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVT)
+ return failure();
+ auto resVT = broadcastOp.getResultVectorType();
+ VectorType newSrcVT = srcVT;
+ VectorType newResVT = resVT;
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+ // Reversing allows us to remove dims from the back without keeping track of
+ // removed dimensions.
+ for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+ if (dim.value() == 1 &&
+ !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
+ !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
+ newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
+ dim.index() - 1);
+ newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
+ dim.index() - 1);
+ }
+ }
+
+ if (newSrcVT == srcVT)
+ return failure();
+ auto loc = broadcastOp->getLoc();
+ auto newSource = rewriter.create<vector::ShapeCastOp>(
+ loc, newSrcVT, broadcastOp.getSource());
+ auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+ newOp.getResult());
+ return success();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1827,8 +1887,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimFromBroadcastOp,
+ ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 42bf7201daaa7..bbeffa979e9c3 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -535,6 +535,60 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
// -----
+func.func @drop_broadcast_unit_dim(%arg0 : vector<1x[1]x3x1xf128>) -> vector<4x1x[1]x3x1xf128> {
+ %bc = vector.broadcast %arg0 : vector<1x[1]x3x1xf128> to vector<4x1x[1]x3x1xf128>
+ return %bc : vector<4x1x[1]x3x1xf128>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[1]x3x1xf128>{{.*}}-> vector<4x1x[1]x3x1xf128> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[1]x3x1xf128> to vector<[1]x3xf128>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<[1]x3xf128> to vector<4x[1]x3xf128>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<4x[1]x3xf128> to vector<4x1x[1]x3x1xf128>
+// CHECK: return %[[VAL_3]] : vector<4x1x[1]x3x1xf128>
+
+// -----
+
+func.func @drop_broadcasted_only_unit_dim(%arg0 : vector<1xf32>) -> vector<1x1xf32> {
+ %bc = vector.broadcast %arg0 : vector<1xf32> to vector<1x1xf32>
+ return %bc : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcasted_only_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1x1xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector<f32>
+// CHECK: %[[VAL_2:.*]] = vector.broadcast %[[VAL_1]] : vector<f32> to vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[VAL_3]] : vector<1x1xf32>
+
+// -----
+
+// Unit dimensions generated by the broadcast itself are not dropped, as we prefer
+// to have a single broadcast rather than a broadcast followed by a shape_cast.
+func.func @drop_broadcast_generated_unit_dim(%arg0 : vector<4xf32>) -> vector<3x1x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<4xf32> to vector<3x1x4xf32>
+ return %bc : vector<3x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @drop_broadcast_generated_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4xf32>{{.*}}-> vector<3x1x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<4xf32> to vector<3x1x4xf32>
+// CHECK: return %[[VAL_1]] : vector<3x1x4xf32>
+
+// -----
+
+// A broadcasted unit dimension must not be dropped, as doing so would cause a type mismatch.
+func.func @drop_broadcasted_unit_dim(%arg0 : vector<2x1x4xf32>) -> vector<2x3x4xf32> {
+ %bc = vector.broadcast %arg0 : vector<2x1x4xf32> to vector<2x3x4xf32>
+ return %bc : vector<2x3x4xf32>
+}
+// CHECK-LABEL: func.func @drop_broadcasted_unit_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x1x4xf32>{{.*}}-> vector<2x3x4xf32> {
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x4xf32> to vector<2x3x4xf32>
+// CHECK: return %[[VAL_1]] : vector<2x3x4xf32>
+
+// -----
+
func.func @negative_out_of_bound_transfer_read(
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
>From cebfd7435ab7b63a32cbbc8f428d504ab3028f4c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 20 Jun 2024 18:28:18 +0800
Subject: [PATCH 2/3] Hoist out vector builder
---
.../Vector/Transforms/VectorTransforms.cpp | 32 +++++++++----------
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 5d8b77621e0d7..70349a61652ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1731,33 +1731,33 @@ struct DropUnitDimFromBroadcastOp final
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
- auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
- if (!srcVT)
+ auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVecTy)
return failure();
- auto resVT = broadcastOp.getResultVectorType();
- VectorType newSrcVT = srcVT;
- VectorType newResVT = resVT;
+ auto resVecTy = broadcastOp.getResultVectorType();
+ auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+ auto resVecTyBuilder = VectorType::Builder(resVecTy);
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
// Reversing allows us to remove dims from the back without keeping track of
// removed dimensions.
- for (const auto &dim : llvm::enumerate(llvm::reverse(srcVT.getShape()))) {
+ for (const auto &dim :
+ llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
if (dim.value() == 1 &&
- !srcVT.getScalableDims()[srcVT.getRank() - dim.index() - 1] &&
- !broadcastedUnitDims.contains(srcVT.getRank() - dim.index() - 1)) {
- newSrcVT = VectorType::Builder(newSrcVT).dropDim(srcVT.getRank() -
- dim.index() - 1);
- newResVT = VectorType::Builder(newResVT).dropDim(resVT.getRank() -
- dim.index() - 1);
+ !srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] &&
+ !broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) {
+ srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1);
+ resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1);
}
}
- if (newSrcVT == srcVT)
+ if (VectorType(srcVecTyBuilder) == srcVecTy)
return failure();
auto loc = broadcastOp->getLoc();
auto newSource = rewriter.create<vector::ShapeCastOp>(
- loc, newSrcVT, broadcastOp.getSource());
- auto newOp = rewriter.create<vector::BroadcastOp>(loc, newResVT, newSource);
- rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVT,
+ loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
+ auto newOp = rewriter.create<vector::BroadcastOp>(
+ loc, VectorType(resVecTyBuilder), newSource);
+ rewriter.replaceOpWithNewOp<ShapeCastOp>(broadcastOp, resVecTy,
newOp.getResult());
return success();
}
>From e5492969c0dc9aef9e44d978b97f97225b01410f Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 20 Jun 2024 22:31:11 +0800
Subject: [PATCH 3/3] Fixup: name index variables
---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 70349a61652ea..e02ef9d0e1a7d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1740,13 +1740,14 @@ struct DropUnitDimFromBroadcastOp final
auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
// Reversing allows us to remove dims from the back without keeping track of
// removed dimensions.
- for (const auto &dim :
+ for (const auto [reversedIndex, dim] :
llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
- if (dim.value() == 1 &&
- !srcVecTy.getScalableDims()[srcVecTy.getRank() - dim.index() - 1] &&
- !broadcastedUnitDims.contains(srcVecTy.getRank() - dim.index() - 1)) {
- srcVecTyBuilder.dropDim(srcVecTy.getRank() - dim.index() - 1);
- resVecTyBuilder.dropDim(resVecTy.getRank() - dim.index() - 1);
+ unsigned srcDimIndex = srcVecTy.getRank() - reversedIndex - 1;
+ unsigned resDimIndex = resVecTy.getRank() - reversedIndex - 1;
+ if (dim == 1 && !srcVecTy.getScalableDims()[srcDimIndex] &&
+ !broadcastedUnitDims.contains(srcDimIndex)) {
+ srcVecTyBuilder.dropDim(srcDimIndex);
+ resVecTyBuilder.dropDim(resDimIndex);
}
}
More information about the Mlir-commits
mailing list