[Mlir-commits] [mlir] [mlir][Vector] Improve support for vector.extract(broadcast) (PR #116234)

Kunwar Grover llvmlistbot at llvm.org
Wed Feb 19 13:11:47 PST 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/116234

>From e0475be3eeadfbf612b247d6133aac0297b2387d Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 14 Nov 2024 14:02:12 +0000
Subject: [PATCH 1/8] [mlir][Vector] Improve dynamic support for
 vector.extract(broadcast) folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 14 ++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 80 +++++++++++++++-------
 2 files changed, 67 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..afaafe3f842bb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1692,6 +1688,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
 
+  // The dim-1 broadcast -> ExtractOp folder requires in place operation
+  // modifications. For dynamic position, this means we have to change the
+  // number of operands. This cannot be done in place since it changes the
+  // operation storage. For dynamic dimensions, the dim-1 broadcasting should
+  // be implemented as a canonicalization pattern.
+  // TODO: Implement canonicalization pattern for dim-1 broadcasting +
+  // extractop.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   auto broadcastOp = cast<vector::BroadcastOp>(defOp);
   int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f17d917ca521e..2fe3d73472018 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,24 +710,44 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_type
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_type(%a : f32, 
+                                            %idx0 : index, 
+                                            %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, 
+                                                %idx0 : index) 
+                                                -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>, 
+                                                   %idx0 : index, 
+                                                   %idx1 : index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -747,57 +767,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't batter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//       CHECK:   return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
-  return %r : vector<4xf32>
+  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//  CHECK-SAME:   %[[IDX:.*]]: index
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+// This folder is not yet implemented. Check that this does not fold.
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+                                                            %a : vector<4xf32>, 
+                                                            %idx : index) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32, 
+                                                         %idx0 : index) 
+                                                         -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+                                                         %idx0 : index) 
+                                                         -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
 // -----
 
 // CHECK-LABEL: @fold_extract_shuffle

>From 55c7536ec2192737217b542ac78f6e7a27ecaf45 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 25 Jan 2025 15:35:24 +0000
Subject: [PATCH 2/8] s

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  2 +-
 mlir/test/Dialect/Vector/canonicalize.mlir | 39 +++++++++++-----------
 2 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index afaafe3f842bb..e89849f738b83 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1688,7 +1688,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
 
-  // The dim-1 broadcast -> ExtractOp folder requires in place operation
+  // The dim-1 broadcast -> ExtractOp folder requires in-place operation
   // modifications. For dynamic position, this means we have to change the
   // number of operands. This cannot be done in place since it changes the
   // operation storage. For dynamic dimensions, the dim-1 broadcasting should
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2fe3d73472018..e94915d3b9e84 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,43 +710,42 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_same_type
+// CHECK-LABEL: fold_extract_broadcast_same_input_output
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast_same_type(%a : f32, 
+func.func @fold_extract_broadcast_same_input_output(%a : f32, 
                                             %idx0 : index, 
                                             %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  // The indices don't batter for this folder, so we use mixed indices.
+  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, 
-                                                %idx0 : index) 
-                                                -> vector<4xf32> {
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
+                                                %idx0 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  // The indices don't batter for this folder, so we use mixed indices.
+  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
+// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>, 
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
                                                    %idx0 : index, 
                                                    %idx1 : index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  // The indices don't batter for this folder, so we use mixed indices.
+  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
@@ -769,7 +768,7 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 //       CHECK:   return %[[A]] : f32
 func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  // The indices don't batter for this folder, so we use mixed indices.
+  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
@@ -788,14 +787,14 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_negative
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
 //       CHECK:   %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
 //       CHECK:   return %[[R]] : f32
 // This folder is not yet implemented. Check that this does not fold.
-func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_negative(
                                                             %a : vector<4xf32>, 
                                                             %idx : index) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
@@ -805,29 +804,29 @@ func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
 
 // -----
 
-// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
+// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32, 
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, 
                                                          %idx0 : index) 
                                                          -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  // The indices don't matter for this canonicalizer, so we use mixed indices.
   %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
+// CHECK-LABEL: fold_extract_broadcast_to_equal_rank
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
                                                          %idx0 : index) 
                                                          -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  // The indices don't batter for this canonicalizer, so we use mixed indices.
+  // The indices don't matter for this canonicalizer, so we use mixed indices.
   %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }

>From 5785c7c34415c20f443bcb2e21eb3cc0f88d7124 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sat, 8 Feb 2025 23:43:26 +0000
Subject: [PATCH 3/8] Allow dim-1 broadcasting + dynamic indices

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 22 +++++++---------------
 mlir/test/Dialect/Vector/canonicalize.mlir | 20 ++------------------
 2 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index e89849f738b83..7eaf34a55cb4a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1688,16 +1688,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
 
-  // The dim-1 broadcast -> ExtractOp folder requires in-place operation
-  // modifications. For dynamic position, this means we have to change the
-  // number of operands. This cannot be done in place since it changes the
-  // operation storage. For dynamic dimensions, the dim-1 broadcasting should
-  // be implemented as a canonicalization pattern.
-  // TODO: Implement canonicalization pattern for dim-1 broadcasting +
-  // extractop.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   auto broadcastOp = cast<vector::BroadcastOp>(defOp);
   int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
 
@@ -1706,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // extract position to `0` when extracting from the source operand.
   llvm::SetVector<int64_t> broadcastedUnitDims =
       broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
+  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
+  OpBuilder b(extractOp.getContext());
   int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
   for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
     if (broadcastedUnitDims.contains(i))
-      extractPos[i] = 0;
+      extractPos[i] = b.getIndexAttr(0);
   // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
   // matching extract position when extracting from the source operand.
   int64_t rankDiff = broadcastSrcRank - extractResultRank;
   extractPos.erase(extractPos.begin(),
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  OpBuilder b(extractOp.getContext());
-  extractOp.setOperand(0, source);
-  extractOp.setStaticPosition(extractPos);
+  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  extractOp->setOperands(
+      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+  extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e94915d3b9e84..76b9d6ad357d1 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -779,25 +779,9 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
-  return %r : f32
-}
-
-// -----
-
-// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_negative
-//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//  CHECK-SAME:   %[[IDX:.*]]: index
-//       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
-//       CHECK:   return %[[R]] : f32
-// This folder is not yet implemented. Check that this does not fold.
-func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_negative(
-                                                            %a : vector<4xf32>, 
-                                                            %idx : index) -> f32 {
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>, %idx : index) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }

>From 85c2d27652ac64684ee62d06ff93b2d388f4f92b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 9 Feb 2025 00:11:30 +0000
Subject: [PATCH 4/8] Test fixes

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 34 +++++++++++++++++-----
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 76b9d6ad357d1..b5b337e8a621a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,10 +710,10 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_same_input_output
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output(%a : f32, 
+func.func @fold_extract_broadcast_same_input_output_vec(%a : f32, 
                                             %idx0 : index, 
                                             %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
@@ -752,6 +752,22 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_diff_input_output_vec
+//  CHECK-SAME:   %[[A:.*]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>, 
+                                                   %idx0 : index, 
+                                                   %idx1 : index) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_negative
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 //       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
@@ -776,13 +792,17 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
 // -----
 
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
-//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
+//  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>, %idx : index) -> f32 {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
+                                                    %idx : index, 
+                                                    %idx1 : index, 
+                                                    %idx2 : index) -> f32 {
+  %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
-  %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 

>From 0c7a84e6dd76e64a88acd3d24e72d4b216f6bdf7 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 9 Feb 2025 00:14:44 +0000
Subject: [PATCH 5/8] fix name

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b5b337e8a621a..f5dafe35ba6fc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -713,7 +713,7 @@ func.func @fold_extract_transpose(
 // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_vec(%a : f32, 
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
                                             %idx0 : index, 
                                             %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>

>From e2a4e6d7901e7f1f1cd4ecb065a33a21f754580f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 19 Feb 2025 20:00:25 +0000
Subject: [PATCH 6/8] be consistent in style

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 25 ++++++++--------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f5dafe35ba6fc..0ba27d5d3d5c3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,8 +714,7 @@ func.func @fold_extract_transpose(
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
 func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
-                                            %idx0 : index, 
-                                            %idx1 : index) -> f32 {
+  %idx0 : index, idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
@@ -728,7 +727,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
 func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
-                                                %idx0 : index) -> vector<4xf32> {
+  %idx0 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
@@ -742,8 +741,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
-                                                   %idx0 : index, 
-                                                   %idx1 : index) -> f32 {
+  %idx0 : index, idx1 : index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
@@ -758,8 +756,7 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>, 
-                                                   %idx0 : index, 
-                                                   %idx1 : index) -> vector<4xf32> {
+  %idx0 : index, idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -797,9 +794,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
 //       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
 func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
-                                                    %idx : index, 
-                                                    %idx1 : index, 
-                                                    %idx2 : index) -> f32 {
+%idx : index, idx1 : index, idx2 : index) -> f32 {
   %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
   // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -811,9 +806,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, 
-                                                         %idx0 : index) 
-                                                         -> vector<4xf32> {
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index) 
+  -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   // The indices don't matter for this canonicalizer, so we use mixed indices.
   %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
@@ -826,9 +820,8 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32,
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
-                                                         %idx0 : index) 
-                                                         -> vector<8xf32> {
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, idx0 : index) 
+  -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
   // The indices don't matter for this canonicalizer, so we use mixed indices.
   %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>

>From 7900062c39e3d2e114936215f6046e428d88eeeb Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 19 Feb 2025 21:08:10 +0000
Subject: [PATCH 7/8] fix types and use dynamic indices

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 36 +++++++++-------------
 1 file changed, 14 insertions(+), 22 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0ba27d5d3d5c3..98daa57fd6e9b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,10 +714,9 @@ func.func @fold_extract_transpose(
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
 func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
-  %idx0 : index, idx1 : index) -> f32 {
+  %idx0 : index, idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
-  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -727,10 +726,9 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
 func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
-  %idx0 : index) -> vector<4xf32> {
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
-  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -741,10 +739,9 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
-  %idx0 : index, idx1 : index) -> f32 {
+  %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
-  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -756,9 +753,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 func.func @fold_extract_broadcast_diff_input_output_vec(%a : vector<2x4xf32>, 
-  %idx0 : index, idx1 : index) -> vector<4xf32> {
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
@@ -779,10 +775,9 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
-  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -791,12 +786,11 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], %[[IDX2]]] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
 func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
-%idx : index, idx1 : index, idx2 : index) -> f32 {
+  %idx : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
-  // The indices don't matter for this folder, so we use mixed indices.
   %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
@@ -806,11 +800,10 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index) 
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index) 
   -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  // The indices don't matter for this canonicalizer, so we use mixed indices.
-  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -820,10 +813,9 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, idx0 : index)
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, idx0 : index) 
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index) 
   -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  // The indices don't matter for this canonicalizer, so we use mixed indices.
   %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }

>From 110ccdcf89a32915368ff87f8b9076375b06614b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 19 Feb 2025 21:11:19 +0000
Subject: [PATCH 8/8] fix more

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 98daa57fd6e9b..a57b7bb263f15 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -714,7 +714,7 @@ func.func @fold_extract_transpose(
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
 func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
-  %idx0 : index, idx1 : index, %idx2 : index) -> f32 {
+  %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
@@ -786,7 +786,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], %[[IDX2]]] : f32 from vector<2x1xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
 func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
   %idx : index, %idx1 : index, %idx2 : index) -> f32 {



More information about the Mlir-commits mailing list