[Mlir-commits] [mlir] [mlir][Vector] Improve support for vector.extract(broadcast) (PR #116234)

Kunwar Grover llvmlistbot at llvm.org
Thu Nov 14 06:04:57 PST 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/116234

This patch improves support for dynamic positions in the vector.extract(broadcast) folders. This is mostly a matter of moving a conservative early exit on dynamic positions so that it only guards the one case (dim-1 broadcasting) that cannot yet handle them.

This patch also improves test coverage for vector.extract + broadcast folders/canonicalizers. The folders/canonicalizers now enumerate every supported / unsupported case.

>From 154c5509c477e0a814e0787e7c579562545c73e6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 14 Nov 2024 14:02:12 +0000
Subject: [PATCH] [mlir][Vector] Improve dynamic support for
 vector.extract(broadcast) folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 14 ++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 80 +++++++++++++++-------
 2 files changed, 67 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..12f0ae25f4dc7d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1648,10 +1648,6 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1680,6 +1676,16 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
 
+  // The dim-1 broadcast -> ExtractOp folder requires in-place operation
+  // modifications. For a dynamic position, this means we have to change the
+  // number of operands. This cannot be done in place since it changes the
+  // operation storage. For dynamic positions, the dim-1 broadcasting should
+  // be implemented as a canonicalization pattern.
+  // TODO: Implement canonicalization pattern for dim-1 broadcasting +
+  // ExtractOp.
+  if (extractOp.hasDynamicPosition())
+    return Value();
+
   auto broadcastOp = cast<vector::BroadcastOp>(defOp);
   int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..766f0e09b6d753 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -652,24 +652,44 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_type
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_type(%a : f32, 
+                                            %idx0 : index, 
+                                            %idx1 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_type_vec
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_type_vec(%a : vector<4xf32>, 
+                                                %idx0 : index) 
+                                                -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_and_scalar
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_and_scalar(%a : vector<f32>, 
+                                                   %idx0 : index, 
+                                                   %idx1 : index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -689,57 +709,71 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  // The indices don't matter for this folder, so we use mixed indices.
+  %r = vector.extract %b[%idx0, %idx1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//       CHECK:   return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<4xf32>) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
-  return %r : vector<4xf32>
+  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting_dynamic_nyi
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
+//  CHECK-SAME:   %[[IDX:.*]]: index
+//       CHECK:   %[[B:.*]] = vector.broadcast %[[A]] : vector<4xf32> to vector<1x2x4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[B]][%[[IDX]], 1, 2]
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+// This folder is not yet implemented. Check that this does not fold.
+func.func @fold_extract_broadcast_dim1_broadcasting_dynamic_nyi(
+                                                            %a : vector<4xf32>, 
+                                                            %idx : index) -> f32 {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, 1, 2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+func.func @canonicalize_extract_broadcast_to_higher_rank(%a : f32, 
+                                                         %idx0 : index) 
+                                                         -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+  // The indices don't matter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[0, %idx0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: canonicalize_extract_broadcast_to_equal_rank
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+func.func @canonicalize_extract_broadcast_to_equal_rank(%a : vector<1xf32>,
+                                                         %idx0 : index) 
+                                                         -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+  // The indices don't matter for this canonicalizer, so we use mixed indices.
+  %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
 // -----
 
 // CHECK-LABEL: @fold_extract_shuffle



More information about the Mlir-commits mailing list