[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)
Chao Chen
llvmlistbot at llvm.org
Thu May 29 12:33:13 PDT 2025
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/142011
>From 032284e64495caabf8d65479103ef00a8e22efff Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 18:39:09 +0000
Subject: [PATCH 1/3] add unroll pattern for broadcast
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++
.../Vector/Transforms/VectorUnroll.cpp | 65 +++++++++++++++++--
.../Dialect/Vector/vector-unroll-options.mlir | 25 ++++++-
.../Dialect/Vector/TestVectorTransforms.cpp | 14 ++--
5 files changed, 97 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..e50cb459b99ac 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..4487590bcb9b7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2401,6 +2401,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+/// Unrolling a broadcast tiles its result, so the unroll shape is the shape of
+/// the result vector.
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+ ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
+ return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..1f50de15ad756 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,69 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+/// Unrolls a `vector.broadcast` into broadcasts of target-shaped tiles.
+///
+/// The result is walked tile by tile. Each tile gets its own smaller
+/// broadcast, whose value is written into an accumulator with
+/// `vector.insert_strided_slice`. A scalar source is reused unchanged for
+/// every tile, while a vector source is first narrowed with
+/// `vector.extract_strided_slice`.
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+ UnrollBroadcastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+ options(options) {}
+
+ /// Fails to match when getTargetShape() yields no shape for this op.
+ /// NOTE(review): presumably that covers "filtered out" and "already at the
+ /// native shape"; the helper is defined outside this hunk, so confirm.
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, broadcastOp);
+ if (!targetShape)
+ return failure();
+
+ Location loc = broadcastOp.getLoc();
+ // Null when the broadcast source is a scalar rather than a vector.
+ VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ VectorType resType = broadcastOp.getResultVectorType();
+ // Type of one unrolled tile: the target shape, same element type.
+ VectorType newType =
+ resType.cloneWith(*targetShape, resType.getElementType());
+ // Accumulator for the full result, seeded with zeros; every tile is
+ // inserted into it by the loop below.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, resType, rewriter.getZeroAttr(resType));
+
+ // For broadcast this is the result shape, so the optional is always set.
+ SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+
+ // `offsets` is the position of the current tile within the result.
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalShape, *targetShape)) {
+ Value newSrc;
+ // Scalar to vector broadcast.
+ if (!srcType) {
+ newSrc = broadcastOp.getSource();
+ } else {
+ int64_t rank = srcType.getRank();
+ auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).drop_front(rank);
+ auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).drop_front(rank);
+ auto srcStrides = llvm::ArrayRef<int64_t>(strides).drop_front(rank);
+ newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+ }
+
+ // Clone the broadcast with the tile's source and the tile result type
+ // (helper defined outside this hunk).
+ Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+ newSrc, newType);
+
+ // Write the tile into the accumulator at the tile's offset.
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, newOp->getResult(0), result, offsets, strides);
+ }
+
+ rewriter.replaceOp(broadcastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
- UnrollContractionPattern, UnrollElementwisePattern,
- UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTransposePattern, UnrollGatherPattern>(
- patterns.getContext(), options, benefit);
+ // Every unroll pattern is registered with the same options and benefit.
+ MLIRContext *ctx = patterns.getContext();
+ patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+ UnrollContractionPattern, UnrollElementwisePattern,
+ UnrollReductionPattern, UnrollMultiReductionPattern,
+ UnrollTransposePattern, UnrollGatherPattern,
+ UnrollBroadcastPattern>(ctx, options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fcbf1d13d1cee 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK: return
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,26 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
// BATCHED-COUNT-16: vector.contract
// BATCHED-NOT: vector.contract
// BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+ %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+// CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]]
+
+// Source with a leading unit dim: every tile must slice row 0 of the source
+// (offset 0, size 1 in that dim) and let the broadcast stretch it.
+func.func @vector_broadcast_unit_dim(%v: vector<1x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.broadcast %v : vector<1x4xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast_unit_dim
+// CHECK-SAME: [[arg0:%.+]]: vector<1x4xf32>
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ccba2e2806862..c8d662c83c3af 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateVectorUnrollPatterns(
- patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{2, 2})
- .setFilterConstraint([](Operation *op) {
- return success(isa<arith::AddFOp, vector::FMAOp,
- vector::MultiDimReductionOp>(op));
- }));
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2, 2})
+ .setFilterConstraint([](Operation *op) {
+ return success(
+ isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+ vector::BroadcastOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2})
>From df06eea488bf11bd847945fcd29b5bf495680b05 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 19:23:52 +0000
Subject: [PATCH 2/3] fix src slicing: use take_back(rank) instead of drop_front(rank)
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1f50de15ad756..6bf7ae290e626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -663,9 +663,9 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
newSrc = broadcastOp.getSource();
} else {
int64_t rank = srcType.getRank();
- auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).drop_front(rank);
- auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).drop_front(rank);
- auto srcStrides = llvm::ArrayRef<int64_t>(strides).drop_front(rank);
+ auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).take_back(rank);
+ auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).take_back(rank);
+ auto srcStrides = llvm::ArrayRef<int64_t>(strides).take_back(rank);
newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
}
>From 496f3061cc0cb1e7b8e432064c1b0e2028d8045c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 19:32:54 +0000
Subject: [PATCH 3/3] fix broadcast unrolling for source dims of size 1
---
.../Dialect/Vector/Transforms/VectorUnroll.cpp | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 6bf7ae290e626..472262cf5c258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -658,14 +658,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value newSrc;
- // Scalar to vector broadcast.
if (!srcType) {
+ // Scalar to vector broadcast.
newSrc = broadcastOp.getSource();
} else {
+ // Vector to vector broadcast: src maps onto the trailing result dims.
int64_t rank = srcType.getRank();
- auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).take_back(rank);
- auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).take_back(rank);
- auto srcStrides = llvm::ArrayRef<int64_t>(strides).take_back(rank);
+ SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+ SmallVector<int64_t> srcShape(targetShape->end() - rank,
+ targetShape->end());
+ SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+ // Unit src dims are broadcast: always slice them at offset 0, size 1.
+ for (int64_t i = 0; i < rank; ++i) {
+ if (srcType.getDimSize(i) == 1) {
+ srcOffsets[i] = 0;
+ srcShape[i] = 1;
+ }
+ }
newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
}
More information about the Mlir-commits
mailing list