[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Tue Sep 30 04:18:51 PDT 2025
https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/160318
>From 691bf891868130e8eb66953de2648e1f1befcd31 Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Tue, 23 Sep 2025 14:59:27 +0000
Subject: [PATCH 1/4] Added folder for broadcast->to_elements
Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 94 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 26 ++++++
2 files changed, 119 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 347141e2773b8..4ac61418b97a5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}
+/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
+/// vector.to_elements on the source and remapping results according to
+/// broadcast semantics.
+///
+/// Cases handled:
+/// - %x is a scalar: replicate the scalar across all results.
+/// - %x is a vector: create to_elements on source and remap/duplicate results.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+ // Bail on scalable vectors.
+ if (resultVecType.getNumScalableDims() != 0)
+ return failure();
+
+ // Case 1: scalar broadcast → replicate scalar across all results.
+ if (!isa<VectorType>(bcastOp.getSource().getType())) {
+ Value scalar = bcastOp.getSource();
+ results.assign(resultVecType.getNumElements(), scalar);
+ return success();
+ }
+
+ // Case 2: vector broadcast → create to_elements on source and remap.
+ auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
+ if (srcVecType.getNumScalableDims() != 0)
+ return failure();
+
+ // Create a temporary to_elements to get the source elements for mapping.
+ // Change the operand to the broadcast source.
+ OpBuilder builder(toElementsOp);
+ auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
+ bcastOp.getSource());
+
+ ArrayRef<int64_t> dstShape = resultVecType.getShape();
+ ArrayRef<int64_t> srcShape = srcVecType.getShape();
+
+ // Quick broadcastability check with right-aligned shapes.
+ unsigned dstRank = dstShape.size();
+ unsigned srcRank = srcShape.size();
+ if (srcRank > dstRank)
+ return failure();
+
+ for (unsigned i = 0; i < dstRank; ++i) {
+ int64_t dstDim = dstShape[i];
+ int64_t srcDim = 1;
+ if (i + srcRank >= dstRank)
+ srcDim = srcShape[i + srcRank - dstRank];
+ if (!(srcDim == 1 || srcDim == dstDim))
+ return failure();
+ }
+
+ int64_t dstCount = 1;
+ for (int64_t v : dstShape)
+ dstCount *= v;
+ results.clear();
+ results.reserve(dstCount);
+
+ // Pre-compute the mapping from destination linear index to source linear index
+ SmallVector<int64_t> dstToSrcMap(dstCount);
+ SmallVector<int64_t> dstIdx(dstShape.size());
+
+ for (int64_t lin = 0; lin < dstCount; ++lin) {
+ // Convert linear index to multi-dimensional indices (row-major order)
+ int64_t temp = lin;
+ for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+ int64_t dim = dstShape[i];
+ dstIdx[i] = temp % dim;
+ temp /= dim;
+ }
+ // Right-align mapping from dst indices to src indices.
+ int64_t srcLin = 0;
+ for (unsigned k = 0; k < srcRank; ++k)
+ srcLin = srcLin * srcShape[k] +
+ ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
+ dstToSrcMap[lin] = srcLin;
+ }
+
+ // Apply the pre-computed mapping
+ for (int64_t lin = 0; lin < dstCount; ++lin) {
+ results.push_back(srcElems.getResult(dstToSrcMap[lin]));
+ }
+ return success();
+}
+
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- return foldToElementsFromElements(*this, results);
+ if (succeeded(foldToElementsFromElements(*this, results)))
+ return success();
+ return foldToElementsOfBroadcast(*this, results);
}
+
LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ToElementsOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 08d28be3f8f73..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3326,6 +3326,32 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
// -----
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
+ %v = vector.broadcast %s : f32 to vector<4xf32>
+ %e:4 = vector.to_elements %v : vector<4xf32>
+ // CHECK-NOT: vector.broadcast
+ // CHECK-NOT: vector.to_elements
+ // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
+ return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+ %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+ %e:6 = vector.to_elements %v : vector<3x2xf32>
+ // CHECK-NOT: vector.broadcast
+ // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+ // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+ return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
// +---------------------------------------------------------------------------
// Tests for foldFromElementsToConstant
// +---------------------------------------------------------------------------
>From c80ccba208a3f2fcf7f7f9b92b03b60fa9168cde Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 01:54:45 -0700
Subject: [PATCH 2/4] Vector broadcast->to_elements requires a canonicalizer
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 153 ++++++++++--------
mlir/test/Dialect/Vector/canonicalize.mlir | 25 +++
3 files changed, 113 insertions(+), 66 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..61fde8c0ba8c5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -805,6 +805,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
let results = (outs Variadic<AnyType>:$elements);
let assemblyFormat = "$source attr-dict `:` type($source)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Vector_FromElementsOp : Vector_Op<"from_elements", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4ac61418b97a5..fa40af6b16528 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,13 +2395,12 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}
-/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
-/// vector.to_elements on the source and remapping results according to
-/// broadcast semantics.
+/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
///
/// Cases handled:
/// - %x is a scalar: replicate the scalar across all results.
-/// - %x is a vector: create to_elements on source and remap/duplicate results.
+///
+/// The vector source case is handled by a canonicalization pattern.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
SmallVectorImpl<OpFoldResult> &results) {
@@ -2421,67 +2420,8 @@ foldToElementsOfBroadcast(ToElementsOp toElementsOp,
return success();
}
- // Case 2: vector broadcast → create to_elements on source and remap.
- auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
- if (srcVecType.getNumScalableDims() != 0)
- return failure();
-
- // Create a temporary to_elements to get the source elements for mapping.
- // Change the operand to the broadcast source.
- OpBuilder builder(toElementsOp);
- auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
- bcastOp.getSource());
-
- ArrayRef<int64_t> dstShape = resultVecType.getShape();
- ArrayRef<int64_t> srcShape = srcVecType.getShape();
-
- // Quick broadcastability check with right-aligned shapes.
- unsigned dstRank = dstShape.size();
- unsigned srcRank = srcShape.size();
- if (srcRank > dstRank)
- return failure();
-
- for (unsigned i = 0; i < dstRank; ++i) {
- int64_t dstDim = dstShape[i];
- int64_t srcDim = 1;
- if (i + srcRank >= dstRank)
- srcDim = srcShape[i + srcRank - dstRank];
- if (!(srcDim == 1 || srcDim == dstDim))
- return failure();
- }
-
- int64_t dstCount = 1;
- for (int64_t v : dstShape)
- dstCount *= v;
- results.clear();
- results.reserve(dstCount);
-
- // Pre-compute the mapping from destination linear index to source linear index
- SmallVector<int64_t> dstToSrcMap(dstCount);
- SmallVector<int64_t> dstIdx(dstShape.size());
-
- for (int64_t lin = 0; lin < dstCount; ++lin) {
- // Convert linear index to multi-dimensional indices (row-major order)
- int64_t temp = lin;
- for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
- int64_t dim = dstShape[i];
- dstIdx[i] = temp % dim;
- temp /= dim;
- }
- // Right-align mapping from dst indices to src indices.
- int64_t srcLin = 0;
- for (unsigned k = 0; k < srcRank; ++k)
- srcLin = srcLin * srcShape[k] +
- ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
-
- dstToSrcMap[lin] = srcLin;
- }
-
- // Apply the pre-computed mapping
- for (int64_t lin = 0; lin < dstCount; ++lin) {
- results.push_back(srcElems.getResult(dstToSrcMap[lin]));
- }
- return success();
+ // Vector source case is not folded here.
+ return failure();
}
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
@@ -2491,7 +2431,6 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsOfBroadcast(*this, results);
}
-
LogicalResult
ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ToElementsOp::Adaptor adaptor,
@@ -2502,6 +2441,88 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+namespace {
+
+struct ToElementsOfVectorBroadcast final
+ : public OpRewritePattern<ToElementsOp> {
+ using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+ PatternRewriter &rewriter) const override {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ // Only handle broadcasts from a vector source here.
+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+ if (!srcType)
+ return failure();
+
+ auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ // Bail on scalable vectors.
+ if (srcType.getNumScalableDims() != 0 || dstType.getNumScalableDims() != 0)
+ return failure();
+
+ ArrayRef<int64_t> dstShape = dstType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+
+ unsigned dstRank = dstShape.size();
+ unsigned srcRank = srcShape.size();
+ if (srcRank > dstRank)
+ return failure();
+
+ // Verify broadcastability (right-aligned)
+ for (unsigned i = 0; i < dstRank; ++i) {
+ int64_t dstDim = dstShape[i];
+ int64_t srcDim = 1;
+ if (i + srcRank >= dstRank)
+ srcDim = srcShape[i + srcRank - dstRank];
+ if (!(srcDim == 1 || srcDim == dstDim))
+ return failure();
+ }
+
+ // Create elements for the broadcast source vector.
+ auto loc = toElementsOp.getLoc();
+ auto srcElems = rewriter.create<ToElementsOp>(loc, bcastOp.getSource());
+
+ int64_t dstCount = 1;
+ for (int64_t v : dstShape)
+ dstCount *= v;
+
+ SmallVector<Value> replacements;
+ replacements.reserve(dstCount);
+
+ // Pre-compute and apply mapping from destination linear index to
+ // source linear index (row-major, right-aligned broadcasting).
+ SmallVector<int64_t> dstIdx(dstShape.size());
+ for (int64_t lin = 0; lin < dstCount; ++lin) {
+ int64_t temp = lin;
+ for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+ int64_t dim = dstShape[i];
+ dstIdx[i] = temp % dim;
+ temp /= dim;
+ }
+ int64_t srcLin = 0;
+ for (unsigned k = 0; k < srcRank; ++k)
+ srcLin = srcLin * srcShape[k] +
+ ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
+ replacements.push_back(srcElems.getResult(srcLin));
+ }
+
+ rewriter.replaceOp(toElementsOp, replacements);
+ return success();
+ }
+};
+
+} // end anonymous namespace
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ToElementsOfVectorBroadcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 728c4ddd22ec7..ab6d2ec9a835b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3816,3 +3816,28 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
return %v_1 : vector<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: @fold_to_elements_of_scalar_broadcast
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @fold_to_elements_of_scalar_broadcast(%s: f32) -> (f32, f32, f32, f32) {
+ %v = vector.broadcast %s : f32 to vector<4xf32>
+ %e:4 = vector.to_elements %v : vector<4xf32>
+ // CHECK-NOT: vector.broadcast
+ // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]] : f32, f32, f32, f32
+ return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @canonicalize_to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+ %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+ %e:6 = vector.to_elements %v : vector<3x2xf32>
+ // CHECK-NOT: vector.broadcast
+ // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+ // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+ return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
>From 5d04df4fbb5c4bee4093a4024414685f53ec9843 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 01:59:50 -0700
Subject: [PATCH 3/4] Removed duplicate lit
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/test/Dialect/Vector/canonicalize.mlir | 25 ----------------------
1 file changed, 25 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ab6d2ec9a835b..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3816,28 +3816,3 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
return %v_1 : vector<4xf32>
}
-
-// -----
-
-// CHECK-LABEL: @fold_to_elements_of_scalar_broadcast
-// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
-func.func @fold_to_elements_of_scalar_broadcast(%s: f32) -> (f32, f32, f32, f32) {
- %v = vector.broadcast %s : f32 to vector<4xf32>
- %e:4 = vector.to_elements %v : vector<4xf32>
- // CHECK-NOT: vector.broadcast
- // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]] : f32, f32, f32, f32
- return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: @canonicalize_to_elements_of_vector_broadcast
-// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
-func.func @canonicalize_to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
- %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
- %e:6 = vector.to_elements %v : vector<3x2xf32>
- // CHECK-NOT: vector.broadcast
- // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
- // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
- return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
-}
>From 728f51a605324679f1d1cf07daabc307e1b2773f Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 04:07:32 -0700
Subject: [PATCH 4/4] Addressed review comments: 1. Better comments, with
keywords and verbose docs removed. 2. Removed the redundant "Broadcastability"
check; it is not required because the vector.broadcast op is always valid by
the time it reaches this logic.
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 49 ++++++++++--------------
1 file changed, 20 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fa40af6b16528..04a6fbd74245f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
#include <cassert>
#include <cstdint>
+#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
@@ -2409,7 +2410,8 @@ foldToElementsOfBroadcast(ToElementsOp toElementsOp,
return failure();
auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
- // Bail on scalable vectors.
+ // Bail on scalable vectors, since the element count and per-dimension extents
+ // must be known at compile time.
if (resultVecType.getNumScalableDims() != 0)
return failure();
@@ -2441,11 +2443,8 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
-namespace {
-
-struct ToElementsOfVectorBroadcast final
- : public OpRewritePattern<ToElementsOp> {
- using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
PatternRewriter &rewriter) const override {
@@ -2467,34 +2466,28 @@ struct ToElementsOfVectorBroadcast final
ArrayRef<int64_t> dstShape = dstType.getShape();
ArrayRef<int64_t> srcShape = srcType.getShape();
- unsigned dstRank = dstShape.size();
- unsigned srcRank = srcShape.size();
+ int64_t dstRank = dstShape.size();
+ int64_t srcRank = srcShape.size();
if (srcRank > dstRank)
return failure();
- // Verify broadcastability (right-aligned)
- for (unsigned i = 0; i < dstRank; ++i) {
- int64_t dstDim = dstShape[i];
- int64_t srcDim = 1;
- if (i + srcRank >= dstRank)
- srcDim = srcShape[i + srcRank - dstRank];
- if (!(srcDim == 1 || srcDim == dstDim))
- return failure();
- }
-
// Create elements for the broadcast source vector.
- auto loc = toElementsOp.getLoc();
- auto srcElems = rewriter.create<ToElementsOp>(loc, bcastOp.getSource());
+ auto srcElems = rewriter.create<ToElementsOp>(toElementsOp.getLoc(),
+ bcastOp.getSource());
- int64_t dstCount = 1;
- for (int64_t v : dstShape)
- dstCount *= v;
+ int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
+ std::multiplies<int64_t>());
SmallVector<Value> replacements;
replacements.reserve(dstCount);
- // Pre-compute and apply mapping from destination linear index to
- // source linear index (row-major, right-aligned broadcasting).
+ // For each element of the destination, determine which element of the
+ // source should be used. We walk all destination positions using a single
+ // counter, decode it into per-dimension indices, then build the matching
+ // source position: use the same index where sizes match, and use 0 where
+ // the source size is 1 (replication). This mapping is needed so we can
+ // replace each result of to_elements with the corresponding element from
+ // the broadcast source.
SmallVector<int64_t> dstIdx(dstShape.size());
for (int64_t lin = 0; lin < dstCount; ++lin) {
int64_t temp = lin;
@@ -2504,7 +2497,7 @@ struct ToElementsOfVectorBroadcast final
temp /= dim;
}
int64_t srcLin = 0;
- for (unsigned k = 0; k < srcRank; ++k)
+ for (int64_t k = 0; k < srcRank; ++k)
srcLin = srcLin * srcShape[k] +
((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
@@ -2516,11 +2509,9 @@ struct ToElementsOfVectorBroadcast final
}
};
-} // end anonymous namespace
-
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ToElementsOfVectorBroadcast>(context);
+ results.add<ToElementsOfBroadcast>(context);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list