[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 29 11:50:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

This PR adds `UnrollBroadcastPattern` to the `VectorUnroll` transform. To support this, it also extends the `BroadcastOp` definition with the `VectorUnrollOpInterface`.

---
Full diff: https://github.com/llvm/llvm-project/pull/142011.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+60-5) 
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+24-1) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+8-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..e50cb459b99ac 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..4487590bcb9b7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2401,6 +2401,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..1f50de15ad756 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,69 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType newType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      // Scalar to vector broadcast.
+      if (!srcType) {
+        newSrc = broadcastOp.getSource();
+      } else {
+        int64_t rank = srcType.getRank();
+        auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).drop_front(rank);
+        auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).drop_front(rank);
+        auto srcStrides = llvm::ArrayRef<int64_t>(strides).drop_front(rank);
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, newType);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+  patterns
+      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+           UnrollContractionPattern, UnrollElementwisePattern,
+           UnrollReductionPattern, UnrollMultiReductionPattern,
+           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+          patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fcbf1d13d1cee 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 // CHECK-LABEL: func @negative_vector_fma_3d
 //   CHECK-NOT: vector.extract_strided_slice
 //       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return 
+//       CHECK: return
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,26 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
 // BATCHED-COUNT-16: vector.contract
 //      BATCHED-NOT: vector.contract
 //          BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+//  CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+//       CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: return [[r3]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ccba2e2806862..c8d662c83c3af 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{2, 2})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp,
-                                           vector::MultiDimReductionOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+                      vector::BroadcastOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2})

``````````

</details>


https://github.com/llvm/llvm-project/pull/142011


More information about the Mlir-commits mailing list