[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #127706)
Prakhar Dixit
llvmlistbot at llvm.org
Thu Feb 20 04:17:37 PST 2025
https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/127706
>From 5fd68e54eee27860d36c5ae04823b2c6129ced07 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 19 Feb 2025 03:35:46 +0530
Subject: [PATCH 1/2] Add check to ensure input vector rank equals target shape
rank
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..82e473ef7e3b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,8 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ if (originalSize.size() != targetShape->size())
+ return failure();
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
>From bd5e0b154b384e2b66e473bf230d9f3a6363f2ba Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Feb 2025 00:08:16 +0530
Subject: [PATCH 2/2] [vector][mlir] Add required comments and test in
VectorUnroll.cpp and invalid.mlir respectively
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 +++
mlir/test/Dialect/Vector/invalid.mlir | 10 ++++++++++
2 files changed, 13 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 82e473ef7e3b0..42f3fab95d975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return failure();
Location loc = op->getLoc();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..d5e47a39e5107 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
}
+// -----
+
+ func.func @extract_strided_slice() -> () {
+ // expected-error@+1 {{expected input vector rank to match target shape rank}}
+ %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+ vector<24x2x2xf32> to vector<2x2xf32>
+ return
+}
+
// -----
#contraction_accesses = [
More information about the Mlir-commits
mailing list