[Mlir-commits] [mlir] [mlir][Vector] Handle 0-rank case in fold instead of RewriterPattern (PR #130168)
Kunwar Grover
llvmlistbot at llvm.org
Thu Mar 6 12:10:38 PST 2025
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/130168
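For a vector.extract of a vector.broadcast, the folder always canonicalizes to a vector.extract operation. The rewrite pattern, by contrast, canonicalizes to a vector.broadcast, except for 0-rank source vectors, where it emits a vector.extractelement instead.
Remove this special case from the rewrite pattern and handle the 0-rank vector case in the folder instead, so both paths produce a vector.extract.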
>From 18f90d02dd4447b201f963f392cc9ed73e2de8dd Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 6 Mar 2025 20:08:19 +0000
Subject: [PATCH] [mlir][Vector] Handle 0-rank case in fold instead of
RewriterPattern
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 +++++-------
mlir/test/Dialect/Vector/canonicalize.mlir | 4 ++--
2 files changed, 7 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..31f1e82ff1174 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1678,7 +1678,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return source;
unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank >= broadcastSrcRank)
+ if (extractResultRank > broadcastSrcRank)
return Value();
// Check that the dimension of the result haven't been broadcasted.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
@@ -2159,13 +2159,11 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
+ // For scalar result, the input can only be a zero-dim vector, which will
+ // be handled by the folder.
+ if (extractResultRank == 0)
+ return failure();
- // Special case if broadcast src is a 0D vector.
- if (extractResultRank == 0) {
- assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
- rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
- return success();
- }
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..8a9204f042ff6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -736,7 +736,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-LABEL: fold_extract_broadcast_0dvec_input_scalar_output
// CHECK-SAME: %[[A:.*]]: vector<f32>
-// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
@@ -2834,7 +2834,7 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%3 = vector.extract %2[] : f32 from vector<f32>
// Broadcast 0D to 3D and extract scalar.
- // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
More information about the Mlir-commits
mailing list