[Mlir-commits] [mlir] f3fa54a - [mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern (#130168)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 24 06:14:27 PDT 2025


Author: Kunwar Grover
Date: 2025-03-24T13:14:24Z
New Revision: f3fa54a191d47809c3385bc655d1d42e6732a212

URL: https://github.com/llvm/llvm-project/commit/f3fa54a191d47809c3385bc655d1d42e6732a212
DIFF: https://github.com/llvm/llvm-project/commit/f3fa54a191d47809c3385bc655d1d42e6732a212.diff

LOG: [mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern (#130168)

For vector.extract, the folder always canonicalizes to a vector.extract
operation, while the rewrite pattern canonicalizes to a vector.broadcast,
except in the case of 0-rank vectors, where it produced a
vector.extractelement instead.

Remove this special case from the rewrite pattern, and instead handle the
0-rank vector case in the folder.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 99055e2158230..15e3ce2ff62a2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1675,7 +1675,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return source;
 
   unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank >= broadcastSrcRank)
+  if (extractResultRank > broadcastSrcRank)
     return Value();
   // Check that the dimension of the result haven't been broadcasted.
   auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2156,13 +2156,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+    // For scalar result, the input can only be a rank-0 vector, which will
+    // be handled by the folder.
+    if (extractResultRank == 0)
+      return failure();
 
-    // Special case if broadcast src is a 0D vector.
-    if (extractResultRank == 0) {
-      assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
-      return success();
-    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8bb6593d99058..b7db8ec834be7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 
 // CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
-//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
 func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
   %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
   %3 = vector.extract %2[] : f32 from vector<f32>
 
   // Broadcast 0D to 3D and extract scalar.
-  // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+  // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
   %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
   %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
 


        


More information about the Mlir-commits mailing list