[Mlir-commits] [mlir] ef5a710 - [mlir][vector] Skip 0D vectors in vector linearization. (#87577)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 3 17:01:00 PDT 2024
Author: Han-Chung Wang
Date: 2024-04-03T17:00:56-07:00
New Revision: ef5a7109116c1615a9c99c8dba6577853beb6c73
URL: https://github.com/llvm/llvm-project/commit/ef5a7109116c1615a9c99c8dba6577853beb6c73
DIFF: https://github.com/llvm/llvm-project/commit/ef5a7109116c1615a9c99c8dba6577853beb6c73.diff
LOG: [mlir][vector] Skip 0D vectors in vector linearization. (#87577)
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
mlir/test/Dialect/Vector/linearize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4fa5b8a4865b4f..b59e9062e5a08e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
+ // There are no dimensions to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
if (trailingVecDimBitWidth >= targetBitWidth)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index f0e9b3a05c066e..212541c79565b6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -146,6 +146,16 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
// -----
+// ALL-LABEL: func.func @test_0d_vector
+func.func @test_0d_vector() -> vector<f32> {
+ // ALL: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+ %0 = arith.constant dense<0.0> : vector<f32>
+ // ALL: return %[[CST]]
+ return %0 : vector<f32>
+}
+
+// -----
+
func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
// expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
%0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
More information about the Mlir-commits
mailing list