[Mlir-commits] [mlir] 61fb954 - [mlir][Vector] Improve support for vector.extract(broadcast) (#116234)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 02:47:42 PST 2025


Author: Kunwar Grover
Date: 2025-02-24T16:17:38+05:30
New Revision: 61fb9541095316db352f0e4da855305d1715da10

URL: https://github.com/llvm/llvm-project/commit/61fb9541095316db352f0e4da855305d1715da10
DIFF: https://github.com/llvm/llvm-project/commit/61fb9541095316db352f0e4da855305d1715da10.diff

LOG: [mlir][Vector] Improve support for vector.extract(broadcast) (#116234)

This patch improves the vector.extract(broadcast) folder's support for
dynamic positions. This is mostly a matter of removing a conservative
early exit for dynamic positions. The broadcast folder for
vector.extract now covers the cases that the vector.extractelement +
broadcast folder does.

This patch also improves test coverage for the vector.extract + broadcast
folders/canonicalizers. The tests now enumerate every supported and
unsupported case.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..7eaf34a55cb4a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1660,10 +1660,6 @@ static bool hasZeroDimVectors(Operation *op) {
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  // TODO: Canonicalization for dynamic position not implemented yet.
-  if (extractOp.hasDynamicPosition())
-    return Value();
-
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
@@ -1700,20 +1696,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   // extract position to `0` when extracting from the source operand.
   llvm::SetVector<int64_t> broadcastedUnitDims =
       broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<int64_t> extractPos(extractOp.getStaticPosition());
+  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
+  OpBuilder b(extractOp.getContext());
   int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
   for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
     if (broadcastedUnitDims.contains(i))
-      extractPos[i] = 0;
+      extractPos[i] = b.getIndexAttr(0);
   // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
   // matching extract position when extracting from the source operand.
   int64_t rankDiff = broadcastSrcRank - extractResultRank;
   extractPos.erase(extractPos.begin(),
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  OpBuilder b(extractOp.getContext());
-  extractOp.setOperand(0, source);
-  extractOp.setStaticPosition(extractPos);
+  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  extractOp->setOperands(
+      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+  extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index f17d917ca521e..bf755b466c7eb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -710,24 +710,38 @@ func.func @fold_extract_transpose(
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast(%a : f32) -> f32 {
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
+  %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
+  %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
@@ -747,57 +761,68 @@ func.func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32) -> f32 {
+func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast_vector
-//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
-  return %r : vector<4xf32>
+// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
+//  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
+//  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
+//       CHECK:   return %[[R]] : f32
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
+  %idx : index, %idx1 : index, %idx2 : index) -> f32 {
+  %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
+  return %r : f32
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
-//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
-//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : f32 from vector<4xf32>
-//       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
-  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1, 2] : f32 from vector<1x2x4xf32>
-  return %r : f32
+// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
+//  CHECK-SAME:   %[[A:.*]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+// rank(extract_output) < rank(broadcast_input)
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, 
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
-func.func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
+// rank(extract_output) > rank(broadcast_input)
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index) 
+  -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
-  %r = vector.extract %b[0, 1] : vector<4xf32> from vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
 // -----
 
-// CHECK-LABEL: fold_extract_broadcast
+// CHECK-LABEL: fold_extract_broadcast_to_equal_rank
 //  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
-func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
+// rank(extract_output) == rank(broadcast_input)
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index) 
+  -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
-  %r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
+  %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
   return %r : vector<8xf32>
 }
+
 // -----
 
 // CHECK-LABEL: @fold_extract_shuffle


        


More information about the Mlir-commits mailing list