[Mlir-commits] [mlir] [mlir][Linalg] Remove implicit zero rank vectors in vectorization (PR #116069)

Kunwar Grover llvmlistbot at llvm.org
Wed Nov 13 07:54:00 PST 2024


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/116069

Vectorization today converts every zero-rank vector it encounters into a scalar. This patch moves that conversion out of the path taken by all operations, so it applies only to operations that do not yet support zero-rank vectors. For Linalg vectorization, these are primarily vector::MultiDimReductionOp and vector::ContractionOp.

From 8023ff66d04a2b7ee87754e354a22e6c1944d7e7 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 13 Nov 2024 15:44:49 +0000
Subject: [PATCH] Remove implicit zero rank vectors in vectorization

---
 .../Linalg/Transforms/Vectorization.cpp        | 18 +++++++++---------
 .../Linalg/vectorization-with-patterns.mlir    |  5 +----
 2 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..9f35e40a964af6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -590,9 +590,6 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
 /// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
   auto dstVecType = dyn_cast<VectorType>(dstType);
-  // If no shape to broadcast to, just return `value`.
-  if (dstVecType.getRank() == 0)
-    return value;
   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
       vector::BroadcastableToResult::Success)
     return value;
@@ -608,6 +605,15 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                       Value valueToReduce, Value acc,
                                       ArrayRef<bool> dimsToMask) {
+  // If `acc` is a zero-rank vector, extract the scalar value from it, since
+  // vector.multi_reduction does not support 0 rank vectors yet.
+  // TODO: Remove this once vector.multi_reduction supports 0 rank vectors.
+  auto accVecType = dyn_cast<VectorType>(acc.getType());
+  if (accVecType && accVecType.getRank() == 0) {
+    acc = b.create<vector::ExtractOp>(reduceOp->getLoc(), acc,
+                                      ArrayRef<int64_t>());
+  }
+
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
@@ -1410,12 +1416,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
 
-    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
-    // TODO: remove this.
-    if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
-                                                     ArrayRef<int64_t>());
-
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
     bvm.map(bbarg, readValue);
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..ee18610071eb20 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1777,11 +1777,8 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func @zero_dim_tensor
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
-//       CHECK:     arith.addf {{.*}} : f32
-//       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
+//       CHECK:     arith.addf {{.*}} : vector<f32>
 //       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
 // -----



More information about the Mlir-commits mailing list