[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #127706)
Prakhar Dixit
llvmlistbot at llvm.org
Fri Feb 21 02:12:18 PST 2025
https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/127706
>From b5715c0ffbcc16cd22e26d5975c332381846b629 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 19 Feb 2025 03:35:46 +0530
Subject: [PATCH 1/4] Add check to ensure input vector rank equals target shape
rank
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..82e473ef7e3b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,8 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ if (originalSize.size() != targetShape->size())
+ return failure();
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
>From 8717cbac20bdc2479be931370a9078cd9dfa13ed Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Feb 2025 00:08:16 +0530
Subject: [PATCH 2/4] [vector][mlir] Add required comments and test in
vectorUnroll.cpp and invalid.mlir respectively
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 +++
mlir/test/Dialect/Vector/invalid.mlir | 10 ++++++++++
2 files changed, 13 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 82e473ef7e3b0..42f3fab95d975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return failure();
Location loc = op->getLoc();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..d5e47a39e5107 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
}
+// -----
+
+ func.func @extract_strided_slice() -> () {
+ // expected-error @+1 {{expected input vector rank to match target shape rank}}
+ %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+ vector<24x2x2xf32> to vector<2x2xf32>
+ return
+}
+
// -----
#contraction_accesses = [
>From de2edb06b14b42e5212028ad4ac608340a730f1c Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 13:37:12 +0530
Subject: [PATCH 3/4] Add a negative test and also update the diagnostic test
with lesser ops
---
mlir/test/Dialect/Vector/invalid.mlir | 7 +++----
.../test/Dialect/Vector/vector-unroll-options.mlir | 14 ++++++++++++++
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d5e47a39e5107..7c1892b8b5f53 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -768,11 +768,10 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
- func.func @extract_strided_slice() -> () {
+func.func @extract_strided_slice(%arg0: vector<3x2x2xf32>) {
// expected-error @+1 {{expected input vector rank to match target shape rank}}
- %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
- %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
- vector<24x2x2xf32> to vector<2x2xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+ vector<3x2x2xf32> to vector<2x2xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 7e3fe56f6b124..d1dd65727e204 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,20 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+func.func @higher_rank_unroll() {
+ %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
+ %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+ %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
+ %818 = scf.execute_region -> vector<24x2x2xf32> {
+ scf.yield %47 : vector<24x2x2xf32>
+ }
+ %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
+ return
+}
+
+// CHECK-LABEL: func @higher_rank_unroll
+// CHECK: return
+
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
>From 88063306e7c7f9d16c21366f9664821c6db11253 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 15:40:18 +0530
Subject: [PATCH 4/4] modify test
---
.../Dialect/Vector/vector-unroll-options.mlir | 17 ++++++-----------
1 file changed, 6 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index d1dd65727e204..9485298aa2d3e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,18 +188,13 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-func.func @higher_rank_unroll() {
- %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
- %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
- %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
- %818 = scf.execute_region -> vector<24x2x2xf32> {
- scf.yield %47 : vector<24x2x2xf32>
- }
- %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
- return
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
+ %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+ return
}
-
-// CHECK-LABEL: func @higher_rank_unroll
+// CHECK-LABEL: func @negative_vector_fma_3d
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
// CHECK: return
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
More information about the Mlir-commits
mailing list