[Mlir-commits] [mlir] ef5a710 - [mlir][vector] Skip 0D vectors in vector linearization. (#87577)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 3 17:01:00 PDT 2024


Author: Han-Chung Wang
Date: 2024-04-03T17:00:56-07:00
New Revision: ef5a7109116c1615a9c99c8dba6577853beb6c73

URL: https://github.com/llvm/llvm-project/commit/ef5a7109116c1615a9c99c8dba6577853beb6c73
DIFF: https://github.com/llvm/llvm-project/commit/ef5a7109116c1615a9c99c8dba6577853beb6c73.diff

LOG: [mlir][vector] Skip 0D vectors in vector linearization. (#87577)

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4fa5b8a4865b4f..b59e9062e5a08e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
     // Reject index since getElementTypeBitWidth will abort for Index types.
     if (!vecType || vecType.getElementType().isIndex())
       return false;
+    // There are no dimensions to fold if it is a 0-D vector.
+    if (vecType.getRank() == 0)
+      return false;
     unsigned trailingVecDimBitWidth =
         vecType.getShape().back() * vecType.getElementTypeBitWidth();
     if (trailingVecDimBitWidth >= targetBitWidth)

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index f0e9b3a05c066e..212541c79565b6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -146,6 +146,16 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x
 
 // -----
 
+// ALL-LABEL: func.func @test_0d_vector
+func.func @test_0d_vector() -> vector<f32> {
+  // ALL: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
+  %0 = arith.constant dense<0.0> : vector<f32>
+  // ALL: return %[[CST]]
+  return %0 : vector<f32>
+}
+
+// -----
+
 func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
   // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
   %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>


        


More information about the Mlir-commits mailing list