[Mlir-commits] [mlir] def37f7 - [mlir][vector] add unroll pattern for broadcast (#142011)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 5 10:42:20 PDT 2025


Author: Chao Chen
Date: 2025-06-05T12:42:16-05:00
New Revision: def37f7e3a66601e044ce49c034293e7e32d2a3b

URL: https://github.com/llvm/llvm-project/commit/def37f7e3a66601e044ce49c034293e7e32d2a3b
DIFF: https://github.com/llvm/llvm-project/commit/def37f7e3a66601e044ce49c034293e7e32d2a3b.diff

LOG: [mlir][vector] add unroll pattern for broadcast (#142011)

This PR adds `UnrollBroadcastPattern` to the `VectorUnroll` transform.
To support this, it also extends the `BroadcastOp` definition with
`VectorUnrollOpInterface`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/test/Dialect/Vector/vector-unroll-options.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..8353314ed958b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -347,6 +347,7 @@ def Vector_MultiDimReductionOp :
 
 def Vector_BroadcastOp :
   Vector_Op<"broadcast", [Pure,
+     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
      DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..3179b4f975404 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2522,6 +2522,10 @@ void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 /// Return the dimensions of the result vector that were formerly ones in the
 /// source tensor and thus correspond to "dim-1" broadcasting.
 static llvm::SetVector<int64_t>

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 1cc477d9dca91..fc443ab0d138e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -631,14 +631,78 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
+  UnrollBroadcastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, broadcastOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = broadcastOp.getLoc();
+    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType resType = broadcastOp.getResultVectorType();
+    VectorType targetType =
+        resType.cloneWith(*targetShape, resType.getElementType());
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+
+    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
+    SmallVector<int64_t> strides(originalShape.size(), 1);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape)) {
+      Value newSrc;
+      if (!srcType) {
+        // Scalar to vector broadcast.
+        newSrc = broadcastOp.getSource();
+      } else {
+        // Vector to vector broadcast.
+        int64_t rank = srcType.getRank();
+        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
+        SmallVector<int64_t> srcShape(targetShape->end() - rank,
+                                      targetShape->end());
+        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
+        // adjust the offset and shape for src if the corresponding dim is 1.
+        for (int64_t i = 0; i < rank; ++i) {
+          if (srcType.getDimSize(i) == 1) {
+            srcOffsets[i] = 0;
+            srcShape[i] = 1;
+          }
+        }
+        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
+      }
+
+      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
+                                                     newSrc, targetType);
+
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, newOp->getResult(0), result, offsets, strides);
+    }
+
+    rewriter.replaceOp(broadcastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-               UnrollContractionPattern, UnrollElementwisePattern,
-               UnrollReductionPattern, UnrollMultiReductionPattern,
-               UnrollTransposePattern, UnrollGatherPattern>(
-      patterns.getContext(), options, benefit);
+  patterns
+      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+           UnrollContractionPattern, UnrollElementwisePattern,
+           UnrollReductionPattern, UnrollMultiReductionPattern,
+           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
+          patterns.getContext(), options, benefit);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9c158d05b723c..fbb178fb49d87 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -196,7 +196,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 // CHECK-LABEL: func @negative_vector_fma_3d
 //   CHECK-NOT: vector.extract_strided_slice
 //       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return 
+//       CHECK: return
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -311,3 +311,70 @@ func.func @vector_contract_batched(%lhs: vector<8x8x4xf32>, %rhs: vector<8x8x4xf
 // BATCHED-COUNT-16: vector.contract
 //      BATCHED-NOT: vector.contract
 //          BATCHED: return
+
+
+func.func @vector_broadcast(%v: vector<4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @vector_broadcast
+//  CHECK-SAME: [[arg0:%.+]]: vector<4xf32>
+//       CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+//       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2xf32> to vector<2x2xf32>
+//       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_leading_unit_dim(%v: vector<1x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<1x4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_leading_unit_dim
+//  CHECK-SAME: ([[arg0:%.+]]: vector<1x4xf32>) -> vector<4x4xf32> {
+//       CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+//       CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<1x2xf32> to vector<2x2xf32>
+//       CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+//       CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<1x2xf32> to vector<2x2xf32>
+//       CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+//       CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<1x2xf32> to vector<2x2xf32>
+//       CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+//       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<1x2xf32> to vector<2x2xf32>
+//       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: return [[r3]] : vector<4x4xf32>
+
+func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector<4x4xf32> {
+  %0 = vector.broadcast %v : vector<4x1xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @vector_broadcast_with_tailing_unit_dim
+//  CHECK-SAME: ([[arg0:%.+]]: vector<4x1xf32>) -> vector<4x4xf32> {
+//       CHECK: [[c:%.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+//       CHECK: [[s0:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+//       CHECK: [[b0:%.+]] = vector.broadcast [[s0]] : vector<2x1xf32> to vector<2x2xf32>
+//       CHECK: [[r0:%.+]] = vector.insert_strided_slice [[b0]], [[c]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s1:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+//       CHECK: [[b1:%.+]] = vector.broadcast [[s1]] : vector<2x1xf32> to vector<2x2xf32>
+//       CHECK: [[r1:%.+]] = vector.insert_strided_slice [[b1]], [[r0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s2:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+//       CHECK: [[b2:%.+]] = vector.broadcast [[s2]] : vector<2x1xf32> to vector<2x2xf32>
+//       CHECK: [[r2:%.+]] = vector.insert_strided_slice [[b2]], [[r1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: [[s3:%.+]] = vector.extract_strided_slice [[arg0]] {offsets = [2, 0], sizes = [2, 1], strides = [1, 1]} : vector<4x1xf32> to vector<2x1xf32>
+//       CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
+//       CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+//       CHECK: return [[r3]] : vector<4x4xf32>

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f4f32e9339870..54aa96ba89a00 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -157,12 +157,14 @@ struct TestVectorUnrollingPatterns
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     populateVectorUnrollPatterns(
-        patterns, UnrollVectorOptions()
-                      .setNativeShape(ArrayRef<int64_t>{2, 2})
-                      .setFilterConstraint([](Operation *op) {
-                        return success(isa<arith::AddFOp, vector::FMAOp,
-                                           vector::MultiDimReductionOp>(op));
-                      }));
+        patterns,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
+                      vector::BroadcastOp>(op));
+            }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{2})


        


More information about the Mlir-commits mailing list