[Mlir-commits] [mlir] [mlir][vector] add unroll pattern for broadcast (PR #142011)

Chao Chen llvmlistbot at llvm.org
Thu May 29 12:33:13 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/142011

>From 032284e64495caabf8d65479103ef00a8e22efff Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 18:39:09 +0000
Subject: [PATCH 1/3] add unroll pattern for broadcast

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  4 ++
 .../Vector/Transforms/VectorUnroll.cpp        | 65 +++++++++++++++++--
 .../Dialect/Vector/vector-unroll-options.mlir | 25 ++++++-
 .../Dialect/Vector/TestVectorTransforms.cpp   | 14 ++--
 5 files changed, 97 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..e50cb459b99ac 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..4487590bcb9b7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2401,6 +2401,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..1f50de15ad756 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,69 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType newType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      // Scalar to vector broadcast.
+      if (!srcType) {
+        newSrc = broadcastOp.getSource();
+      } else {
+        int64_t rank = srcType.getRank();
+        auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).drop_front(rank);
+        auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).drop_front(rank);
+        auto srcStrides = llvm::ArrayRef<int64_t>(strides).drop_front(rank);
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, newType);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+  patterns
+      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+           UnrollContractionPattern, UnrollElementwisePattern,
+           UnrollReductionPattern, UnrollMultiReductionPattern,
+           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+          patterns.getContext(), options, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fcbf1d13d1cee 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 // CHECK-LABEL: func @negative_vector_fma_3d
 //   CHECK-NOT: vector.extract_strided_slice
 //       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return 
+//       CHECK: return
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,26 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
 // BATCHED-COUNT-16: vector.contract
 //      BATCHED-NOT: vector.contract
 //          BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+//  CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+//       CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: return [[r3]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ccba2e2806862..c8d662c83c3af 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{2, 2})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp,
-                                           vector::MultiDimReductionOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+                      vector::BroadcastOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2})

>From df06eea488bf11bd847945fcd29b5bf495680b05 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 19:23:52 +0000
Subject: [PATCH 2/3] fix an error

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1f50de15ad756..6bf7ae290e626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -663,9 +663,9 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
         newSrc = broadcastOp.getSource();
       } else {
         int64_t rank = srcType.getRank();
-        auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).drop_front(rank);
-        auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).drop_front(rank);
-        auto srcStrides = llvm::ArrayRef<int64_t>(strides).drop_front(rank);
+        auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).take_back(rank);
+        auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).take_back(rank);
+        auto srcStrides = llvm::ArrayRef<int64_t>(strides).take_back(rank);
         newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
       }

>From 496f3061cc0cb1e7b8e432064c1b0e2028d8045c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 29 May 2025 19:32:54 +0000
Subject: [PATCH 3/3] fix a bug

---
 .../Dialect/Vector/Transforms/VectorUnroll.cpp  | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 6bf7ae290e626..472262cf5c258 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -658,14 +658,23 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalShape, *targetShape)) {
       Value newSrc;
-      // Scalar to vector broadcast.
       if (!srcType) {
+        // Scalar to vector broadcast.
         newSrc = broadcastOp.getSource();
       } else {
+        // Vector to vector broadcast.
         int64_t rank = srcType.getRank();
-        auto srcOffsets = llvm::ArrayRef<int64_t>(offsets).take_back(rank);
-        auto srcShape = llvm::ArrayRef<int64_t>(*targetShape).take_back(rank);
-        auto srcStrides = llvm::ArrayRef<int64_t>(strides).take_back(rank);
+        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - rank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+        // Adjust the offset and shape for src if the corresponding dim is 1.
+        for (int64_t i = 0; i < rank; ++i) {
+          if (srcType.getDimSize(i) == 1) {
+            srcOffsets[i] = 0;
+            srcShape[i] = 1;
+          }
+        }
         newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
       }



More information about the Mlir-commits mailing list