[Mlir-commits] [mlir] 3aeb27d - [mlir][Vector] Fix 0-D tensor vectorization in Linalg
Diego Caballero
llvmlistbot at llvm.org
Fri Jun 16 16:46:20 PDT 2023
Author: Diego Caballero
Date: 2023-06-16T23:45:03Z
New Revision: 3aeb27d69a16e7d8aa3fe684c33d506cdda97b78
URL: https://github.com/llvm/llvm-project/commit/3aeb27d69a16e7d8aa3fe684c33d506cdda97b78
DIFF: https://github.com/llvm/llvm-project/commit/3aeb27d69a16e7d8aa3fe684c33d506cdda97b78.diff
LOG: [mlir][Vector] Fix 0-D tensor vectorization in Linalg
It looks like scalable vector support broke vectorization for 0-D
tensors and we didn't have any test covering that case. This patch
provides a fix and a test.
Differential Revision: https://reviews.llvm.org/D153181
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 685567d1631be..bbcde44f08618 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1199,38 +1199,43 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
// a. Get the first max ranked shape.
VectorType firstMaxRankedType;
for (Value operand : op->getOperands()) {
- auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+ auto vecOperand = bvm.lookup(operand);
+ assert(vecOperand && "Vector operand couldn't be found");
+
+ auto vecType = dyn_cast<VectorType>(vecOperand.getType());
if (vecType && (!firstMaxRankedType ||
firstMaxRankedType.getRank() < vecType.getRank()))
firstMaxRankedType = vecType;
}
// b. Broadcast each op if needed.
- SmallVector<Value> vectorizedOperands;
+ SmallVector<Value> vecOperands;
for (Value scalarOperand : op->getOperands()) {
- Value vectorizedOperand = bvm.lookup(scalarOperand);
- auto vecType =
- VectorType::get(firstMaxRankedType.getShape(),
- getElementTypeOrSelf(vectorizedOperand.getType()),
- firstMaxRankedType.getNumScalableDims());
- vectorizedOperands.push_back(
- !firstMaxRankedType
- ? vectorizedOperand
- : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+ Value vecOperand = bvm.lookup(scalarOperand);
+ assert(vecOperand && "Vector operand couldn't be found");
+
+ if (firstMaxRankedType) {
+ auto vecType = VectorType::get(firstMaxRankedType.getShape(),
+ getElementTypeOrSelf(vecOperand.getType()),
+ firstMaxRankedType.getNumScalableDims());
+ vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
+ } else {
+ vecOperands.push_back(vecOperand);
+ }
}
// c. for elementwise, the result is the vector with the firstMaxRankedShape
SmallVector<Type> resultTypes;
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
- !firstMaxRankedType
- ? resultType
- : VectorType::get(firstMaxRankedType.getShape(), resultType,
- firstMaxRankedType.getNumScalableDims()));
+ firstMaxRankedType
+ ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+ firstMaxRankedType.getNumScalableDims())
+ : resultType);
}
// d. Build and return the new op.
return VectorizationResult{
VectorizationStatus::NewOp,
- rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- vectorizedOperands, resultTypes, op->getAttrs())};
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
+ resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 404c3492470fa..130c6bcc11abb 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1719,3 +1719,35 @@ transform.sequence failures(propagate) {
%1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
}
+
+// -----
+
+func.func @zero_dim_tensor(%input: tensor<f32>, %output: tensor<f32>) -> tensor<f32>
+{
+ %0 = linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
+ iterator_types = [] }
+ ins(%input : tensor<f32>)
+ outs(%output : tensor<f32>) {
+ ^bb0(%arg0: f32, %arg1: f32):
+ %2 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+ %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @zero_dim_tensor
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: arith.addf {{.*}} : f32
+// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+
More information about the Mlir-commits
mailing list