[Mlir-commits] [mlir] [mlir][Linalg] Remove implicit zero rank vectors in vectorization (PR #116069)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 13 07:54:35 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
Vectorization today converts any zero-rank vector it encounters into a scalar. This patch moves that check (and the scalar extraction it triggers) out of the general path, so it applies only to operations that do not yet support zero-rank vectors. For Linalg vectorization, these are primarily vector::MultiDimReductionOp and vector::ContractionOp.
---
Full diff: https://github.com/llvm/llvm-project/pull/116069.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+9-9)
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+1-4)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8d..9f35e40a964af6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -590,9 +590,6 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) {
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
auto dstVecType = dyn_cast<VectorType>(dstType);
- // If no shape to broadcast to, just return `value`.
- if (dstVecType.getRank() == 0)
- return value;
if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
vector::BroadcastableToResult::Success)
return value;
@@ -608,6 +605,15 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
Value valueToReduce, Value acc,
ArrayRef<bool> dimsToMask) {
+ // If `acc` is a zero-rank vector, extract the scalar value from it, since
+ // vector.multi_reduction does not support 0 rank vectors yet.
+ // TODO: Remove this once vector.multi_reduction supports 0 rank vectors.
+ auto accVecType = dyn_cast<VectorType>(acc.getType());
+ if (accVecType && accVecType.getRank() == 0) {
+ acc = b.create<vector::ExtractOp>(reduceOp->getLoc(), acc,
+ ArrayRef<int64_t>());
+ }
+
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
@@ -1410,12 +1416,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
- // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
- // TODO: remove this.
- if (readType.getRank() == 0)
- readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
- ArrayRef<int64_t>());
-
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
<< "\n");
bvm.map(bbarg, readValue);
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..ee18610071eb20 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1777,11 +1777,8 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func @zero_dim_tensor
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-// CHECK: vector.extract
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
-// CHECK: vector.extract
-// CHECK: arith.addf {{.*}} : f32
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+// CHECK: arith.addf {{.*}} : vector<f32>
// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/116069
More information about the Mlir-commits
mailing list