[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #127706)

Prakhar Dixit llvmlistbot at llvm.org
Thu Feb 20 04:17:37 PST 2025


https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/127706

>From 5fd68e54eee27860d36c5ae04823b2c6129ced07 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 19 Feb 2025 03:35:46 +0530
Subject: [PATCH 1/2] Add check to ensure input vector rank equals target shape
 rank

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..82e473ef7e3b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,8 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    if (originalSize.size() != targetShape->size())
+      return failure();
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(

>From bd5e0b154b384e2b66e473bf230d9f3a6363f2ba Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Feb 2025 00:08:16 +0530
Subject: [PATCH 2/2] [mlir][vector] Add required comments and test in
 VectorUnroll.cpp and invalid.mlir respectively

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp |  3 +++
 mlir/test/Dialect/Vector/invalid.mlir               | 10 ++++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 82e473ef7e3b0..42f3fab95d975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
     if (originalSize.size() != targetShape->size())
       return failure();
     Location loc = op->getLoc();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..d5e47a39e5107 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
 }
 
+// -----
+
+ func.func @extract_strided_slice() -> () {
+  // expected-error at +1 {{expected input vector rank to match target shape rank}}
+  %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+         vector<24x2x2xf32> to vector<2x2xf32>
+  return
+}
+
 // -----
 
 #contraction_accesses = [



More information about the Mlir-commits mailing list