[Mlir-commits] [mlir] [mlir][Vector] Improve support for vector.extract(broadcast) (PR #116234)
Kunwar Grover
llvmlistbot at llvm.org
Sat Feb 8 16:11:51 PST 2025
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/116234
>From eb8ba22f37c7e00598355c1f85b66143e289a541 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 14 Nov 2024 14:02:12 +0000
Subject: [PATCH 1/4] [mlir][Vector] Improve dynamic support for
vector.extract(broadcast) folders
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++--
mlir/test/Dialect/Vector/canonicalize.mlir | 80 +++++++++++++++-------
2 files changed, 67 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b4a5461f4405dcf..c1df64dd3e8820d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (extractOp.hasDynamicPosition())
- return Value();
-
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
@@ -1692,6 +1688,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
+ // The dim-1 broadcast -> ExtractOp folder requires in place operation
+ // modifications. For dynamic position, this means we have to change the
+ // number of operands. This cannot be done in place since it changes the
+ // operation storage. For dynamic dimensions, the dim-1 broadcasting should
+ // be implemented as a canonicalization pattern.
+ // TODO: Implement canonicalization pattern for dim-1 broadcasting +
+ // extractop.
+ if (extractOp.hasDynamicPosition())
+ return Value();
+
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a74e562ad2f68d7..dec398052d53f3a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,24 +710,44 @@ func.func @fold_extract_transpose(
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_type
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_type(%a : f32,
+ %idx0 : index,
+ %idx1 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
+ %idx0 : index)
+ -> vector<4xf32> {
+ %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>,
+ %idx0 : index,
+ %idx1 : index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -747,57 +767,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
// CHECK-LABEL: fold_extract_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ // The indices don't batter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK: return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
- return %r : vector<4xf32>
+ %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK-SAME: %[[IDX:.*]]: index
+// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+// This folder is not yet implemented. Check that this does not fold.
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+ %a : vector<4xf32>,
+ %idx : index) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+ %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32,
+ %idx0 : index)
+ -> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
+ %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+ %idx0 : index)
+ -> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
- %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+ // The indices don't batter for this canonicalizer, so we use mixed indices.
+ %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
return %r : vector<8xf32>
}
+
// -----
// CHECK-LABEL: @fold_extract_shuffle
>From 5f172d8abcef084d08946431c7da0e23013a37b2 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 25 Jan 2025 15:35:24 +0000
Subject: [PATCH 2/4] s
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
mlir/test/Dialect/Vector/canonicalize.mlir | 39 +++++++++++-----------
2 files changed, 20 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c1df64dd3e8820d..1cda5c637874d12 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1688,7 +1688,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
- // The dim-1 broadcast -> ExtractOp folder requires in place operation
+ // The dim-1 broadcast -> ExtractOp folder requires in-place operation
// modifications. For dynamic position, this means we have to change the
// number of operands. This cannot be done in place since it changes the
// operation storage. For dynamic dimensions, the dim-1 broadcasting should
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index dec398052d53f3a..a24c4931f7e29c2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,43 +710,42 @@ func.func @fold_extract_transpose(
// -----
-// CHECK-LABEL: fold_extract_broadcast_same_type
+// CHECK-LABEL: fold_extract_broadcast_same_input_output
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast_same_type(%a : f32,
+func.func @fold_extract_broadcast_same_input_output(%a : f32,
%idx0 : index,
%idx1 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- // The indices don't batter for this folder, so we use mixed indices.
+ // The indices don't matter for this folder, so we use mixed indices.
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>,
- %idx0 : index)
- -> vector<4xf32> {
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
+ %idx0 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- // The indices don't batter for this folder, so we use mixed indices.
+ // The indices don't matter for this folder, so we use mixed indices.
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
-// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
+// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>,
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index,
%idx1 : index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
- // The indices don't batter for this folder, so we use mixed indices.
+ // The indices don't matter for this folder, so we use mixed indices.
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -769,7 +768,7 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
// CHECK: return %[[A]] : f32
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
- // The indices don't batter for this folder, so we use mixed indices.
+ // The indices don't matter for this folder, so we use mixed indices.
%r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -788,14 +787,14 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
// -----
-// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_negative
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK-SAME: %[[IDX:.*]]: index
// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
// CHECK: return %[[R]] : f32
// This folder is not yet implemented. Check that this does not fold.
-func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_negative(
%a : vector<4xf32>,
%idx : index) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
@@ -805,29 +804,29 @@ func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
// -----
-// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
+// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
-func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32,
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32,
%idx0 : index)
-> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
- // The indices don't batter for this canonicalizer, so we use mixed indices.
+ // The indices don't matter for this canonicalizer, so we use mixed indices.
%r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
-// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
+// CHECK-LABEL: fold_extract_broadcast_to_equal_rank
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
-func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
%idx0 : index)
-> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
- // The indices don't batter for this canonicalizer, so we use mixed indices.
+ // The indices don't matter for this canonicalizer, so we use mixed indices.
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
return %r : vector<8xf32>
}
>From 67dc9527a9ad03d2d54757624f0d2c33948ed225 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 8 Feb 2025 23:43:26 +0000
Subject: [PATCH 3/4] Allow dim-1 broadcasting + dynamic indices
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 22 +++++++---------------
mlir/test/Dialect/Vector/canonicalize.mlir | 20 ++------------------
2 files changed, 9 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1cda5c637874d12..2779cfef0dacd3b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1688,16 +1688,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
- // The dim-1 broadcast -> ExtractOp folder requires in-place operation
- // modifications. For dynamic position, this means we have to change the
- // number of operands. This cannot be done in place since it changes the
- // operation storage. For dynamic dimensions, the dim-1 broadcasting should
- // be implemented as a canonicalization pattern.
- // TODO: Implement canonicalization pattern for dim-1 broadcasting +
- // extractop.
- if (extractOp.hasDynamicPosition())
- return Value();
-
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
@@ -1706,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
- SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
+ SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
+ OpBuilder b(extractOp.getContext());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
- extractPos[i] = 0;
+ extractPos[i] = b.getIndexAttr(0);
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
// matching extract position when extracting from the source operand.
int64_t rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
- OpBuilder b(extractOp.getContext());
- extractOp.setOperand(0, source);
- extractOp.setStaticPosition(extractPos);
+ auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+ extractOp->setOperands(
+ llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+ extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a24c4931f7e29c2..0de7813d1b10a57 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -779,25 +779,9 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
- %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
- return %r : f32
-}
-
-// -----
-
-// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_negative
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK-SAME: %[[IDX:.*]]: index
-// CHECK: %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
-// CHECK: return %[[R]] : f32
-// This folder is not yet implemented. Check that this does not fold.
-func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_negative(
- %a : vector<4xf32>,
- %idx : index) -> f32 {
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>, %idx : index) -> f32 {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+ // The indices don't matter for this folder, so we use mixed indices.
%r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
return %r : f32
}
>From c0f8ffebfe7bd5e9e93037504cffe5aa871b363b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 9 Feb 2025 00:11:30 +0000
Subject: [PATCH 4/4] Test fixes
---
mlir/test/Dialect/Vector/canonicalize.mlir | 34 +++++++++++++++++-----
1 file changed, 27 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0de7813d1b10a57..33dc2804f7f0897 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,10 +710,10 @@ func.func @fold_extract_transpose(
// -----
-// CHECK-LABEL: fold_extract_broadcast_same_input_output
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output(%a : f32,
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
%idx0 : index,
%idx1 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
@@ -752,6 +752,22 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
+// CHECK-LABEL: fold_extract_broadcast_diff_input_output_vec
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>,
+ %idx0 : index,
+ %idx1 : index) -> vector<4xf32> {
+ %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
+ // The indices don't matter for this folder, so we use mixed indices.
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_negative
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -776,13 +792,17 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
// -----
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
-// CHECK-SAME: %[[A:.*]]: vector<4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
+// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
+// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>, %idx : index) -> f32 {
- %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+ %idx : index,
+ %idx1 : index,
+ %idx2 : index) -> f32 {
+ %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
// The indices don't matter for this folder, so we use mixed indices.
- %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
+ %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
More information about the Mlir-commits
mailing list