[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 29 11:50:34 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chao Chen (chencha3)
<details>
<summary>Changes</summary>
This PR adds `UnrollBroadcastPattern` to the `VectorUnroll` transform. To support this, it also extends the `BroadcastOp` definition with `VectorUnrollOpInterface`.
---
Full diff: https://github.com/llvm/llvm-project/pull/142011.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+60-5)
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+24-1)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+8-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..e50cb459b99ac 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
def Vector_BroadcastOp :
Vector_Op<"broadcast", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..4487590bcb9b7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2401,6 +2401,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+/// Reports the shape used when unrolling this op: the shape of the result
+/// vector type.
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  ArrayRef<int64_t> resultShape = getResultVectorType().getShape();
+  return SmallVector<int64_t, 4>(resultShape.begin(), resultShape.end());
+}
+
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..1f50de15ad756 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,69 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
vector::UnrollVectorOptions options;
};
+/// Unrolls a `vector.broadcast` into one smaller broadcast per target-shape
+/// tile of the result.
+///
+/// For every tile offset of the result shape, the matching slice of the source
+/// is extracted (scalar and rank-0 sources are forwarded unchanged), broadcast
+/// to the tile type, and inserted into an accumulator vector. The accumulator
+/// starts as a zero constant of the original result type and, after the last
+/// tile, replaces the original op.
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  /// Rewrites `broadcastOp` tile by tile. Fails (leaving the op untouched)
+  /// when `getTargetShape` yields no target shape for it.
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    // Null when the source is a scalar rather than a vector.
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType newType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    // Only a vector source with at least one dimension can be sliced; a
+    // scalar or rank-0 vector source feeds every tile as-is.
+    const int64_t srcRank = srcType ? srcType.getRank() : 0;
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      if (srcRank == 0) {
+        newSrc = broadcastOp.getSource();
+      } else {
+        // Source dimensions line up with the *trailing* dimensions of the
+        // result, so take the last `srcRank` entries of the tile offsets and
+        // tile shape (not the entries left after dropping `srcRank` leading
+        // ones, which is only equivalent when resRank == 2 * srcRank).
+        SmallVector<int64_t> srcOffsets(offsets.end() - srcRank,
+                                        offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - srcRank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(srcRank, 1);
+        // A unit source dimension is stretched by the broadcast itself: every
+        // tile reads the same single element, so the slice must stay at
+        // offset 0 with size 1 regardless of the tile position.
+        for (int64_t i = 0; i < srcRank; ++i) {
+          if (srcType.getDimSize(i) == 1) {
+            srcOffsets[i] = 0;
+            srcShape[i] = 1;
+          }
+        }
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, newType);
+
+      // Thread the accumulator through each insertion so the final value
+      // holds all tiles.
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  /// Unrolling configuration consulted by `getTargetShape`.
+  vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options,
PatternBenefit benefit) {
- patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
- UnrollContractionPattern, UnrollElementwisePattern,
- UnrollReductionPattern, UnrollMultiReductionPattern,
- UnrollTransposePattern, UnrollGatherPattern>(
- patterns.getContext(), options, benefit);
+ patterns
+ .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+ UnrollContractionPattern, UnrollElementwisePattern,
+ UnrollReductionPattern, UnrollMultiReductionPattern,
+ UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+ patterns.getContext(), options, benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fcbf1d13d1cee 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK: return
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,26 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
// BATCHED-COUNT-16: vector.contract
// BATCHED-NOT: vector.contract
// BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+ %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+// CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+// CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+// CHECK: return [[r3]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ccba2e2806862..c8d662c83c3af 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateVectorUnrollPatterns(
- patterns, UnrollVectorOptions()
- .setNativeShape(ArrayRef<int64_t>{2, 2})
- .setFilterConstraint([](Operation *op) {
- return success(isa<arith::AddFOp, vector::FMAOp,
- vector::MultiDimReductionOp>(op));
- }));
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{2, 2})
+ .setFilterConstraint([](Operation *op) {
+ return success(
+ isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+ vector::BroadcastOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2})
``````````
</details>
https://github.com/llvm/llvm-project/pull/142011
More information about the Mlir-commits
mailing list