[Mlir-commits] [mlir] [mlir][Linalg] Remove implicit zero rank vectors in vectorization (PR #116069)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 07:54:35 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

Vectorization today converts any zero-rank vector it encounters into a scalar. This patch moves that conversion out of the common path so that it applies only to operations that do not yet support zero-rank vectors. For Linalg vectorization, these are primarily vector::MultiDimReductionOp and vector::ContractionOp.

---
Full diff: https://github.com/llvm/llvm-project/pull/116069.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+9-9) 
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+1-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..9f35e40a964af6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -590,9 +590,6 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
 /// otherwise.
 static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
   auto dstVecType = dyn_cast<VectorType>(dstType);
-  // If no shape to broadcast to, just return `value`.
-  if (dstVecType.getRank() == 0)
-    return value;
   if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
       vector::BroadcastableToResult::Success)
     return value;
@@ -608,6 +605,15 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
 static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                       Value valueToReduce, Value acc,
                                       ArrayRef<bool> dimsToMask) {
+  // If `acc` is a zero-rank vector, extract the scalar value from it, since
+  // vector.multi_reduction does not support 0 rank vectors yet.
+  // TODO: Remove this once vector.multi_reduction supports 0 rank vectors.
+  auto accVecType = dyn_cast<VectorType>(acc.getType());
+  if (accVecType && accVecType.getRank() == 0) {
+    acc = b.create<vector::ExtractOp>(reduceOp->getLoc(), acc,
+                                      ArrayRef<int64_t>());
+  }
+
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
@@ -1410,12 +1416,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
 
-    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
-    // TODO: remove this.
-    if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
-                                                     ArrayRef<int64_t>());
-
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
     bvm.map(bbarg, readValue);
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..ee18610071eb20 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1777,11 +1777,8 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func @zero_dim_tensor
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
 //       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-//       CHECK:     vector.extract
-//       CHECK:     arith.addf {{.*}} : f32
-//       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
+//       CHECK:     arith.addf {{.*}} : vector<f32>
 //       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/116069


More information about the Mlir-commits mailing list