[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #127706)
Prakhar Dixit
llvmlistbot at llvm.org
Tue Feb 25 02:10:16 PST 2025
https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/127706
>From 06d4ab23c4cf599b67a11306643be1c78f36a57c Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 19 Feb 2025 03:35:46 +0530
Subject: [PATCH 1/6] Add check to ensure input vector rank equals target shape
rank
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..82e473ef7e3b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,8 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ if (originalSize.size() != targetShape->size())
+ return failure();
Location loc = op->getLoc();
// Prepare the result vector.
Value result = rewriter.create<arith::ConstantOp>(
>From b8df120eb02f85cfb6cb72c12971e8f58de0e6bb Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Feb 2025 00:08:16 +0530
Subject: [PATCH 2/6] [mlir][vector] Add an explanatory comment and a test to
VectorUnroll.cpp and invalid.mlir, respectively
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 +++
mlir/test/Dialect/Vector/invalid.mlir | 10 ++++++++++
2 files changed, 13 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 82e473ef7e3b0..42f3fab95d975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ // Bail-out if rank(source) != rank(target). The main limitation here is the
+ // fact that `ExtractStridedSlice` requires the rank for the input and
+ // output to match. If needed, we can relax this later.
if (originalSize.size() != targetShape->size())
return failure();
Location loc = op->getLoc();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..d5e47a39e5107 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
}
+// -----
+
+ func.func @extract_strided_slice() -> () {
+ // expected-error at +1 {{expected input vector rank to match target shape rank}}
+ %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+ vector<24x2x2xf32> to vector<2x2xf32>
+ return
+}
+
// -----
#contraction_accesses = [
>From 04d4303dd872c3bf915ea689649b29d3f36b733e Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 13:37:12 +0530
Subject: [PATCH 3/6] Add a negative test and simplify the diagnostic test
to use fewer ops
---
mlir/test/Dialect/Vector/invalid.mlir | 7 +++----
.../test/Dialect/Vector/vector-unroll-options.mlir | 14 ++++++++++++++
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d5e47a39e5107..7c1892b8b5f53 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -768,11 +768,10 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
- func.func @extract_strided_slice() -> () {
+func.func @extract_strided_slice(%arg0: vector<3x2x2xf32>) {
// expected-error at +1 {{expected input vector rank to match target shape rank}}
- %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
- %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
- vector<24x2x2xf32> to vector<2x2xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+ vector<3x2x2xf32> to vector<2x2xf32>
return
}
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 7e3fe56f6b124..d1dd65727e204 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,20 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+func.func @higher_rank_unroll() {
+ %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
+ %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+ %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
+ %818 = scf.execute_region -> vector<24x2x2xf32> {
+ scf.yield %47 : vector<24x2x2xf32>
+ }
+ %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
+ return
+}
+
+// CHECK-LABEL: func @higher_rank_unroll
+// CHECK: return
+
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
return %0 : vector<4xf32>
>From 74ec44ae58222c21e61fa385e62d1d8c118e0d92 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 15:40:18 +0530
Subject: [PATCH 4/6] Replace the higher_rank_unroll test with a minimal negative 3-D vector.fma test
---
.../Dialect/Vector/vector-unroll-options.mlir | 17 ++++++-----------
1 file changed, 6 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index d1dd65727e204..9485298aa2d3e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,18 +188,13 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-func.func @higher_rank_unroll() {
- %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
- %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
- %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
- %818 = scf.execute_region -> vector<24x2x2xf32> {
- scf.yield %47 : vector<24x2x2xf32>
- }
- %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
- return
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
+ %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+ return
}
-
-// CHECK-LABEL: func @higher_rank_unroll
+// CHECK-LABEL: func @negative_vector_fma_3d
+// CHECK-NOT: vector.extract_strided_slice
+// CHECK: %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
// CHECK: return
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
>From e0860bd0702f6ecc937ae5ec73311d3a917b8fd2 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Tue, 25 Feb 2025 10:36:01 +0530
Subject: [PATCH 5/6] Add a comment for better clarity in the negative test
case
---
mlir/test/Dialect/Vector/vector-unroll-options.mlir | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9485298aa2d3e..a8f734795a418 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,7 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
+// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return
>From 62961c2b138a121ef5b9a386b262dc9c7752cff0 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Tue, 25 Feb 2025 15:39:04 +0530
Subject: [PATCH 6/6] Use FileCheck patterns instead of hard-coded SSA names in the negative test
---
mlir/test/Dialect/Vector/vector-unroll-options.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index a8f734795a418..4c104d437ae82 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -195,7 +195,7 @@ func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) {
}
// CHECK-LABEL: func @negative_vector_fma_3d
// CHECK-NOT: vector.extract_strided_slice
-// CHECK: %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+// CHECK: %[[R0:.*]] = vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<3x2x2xf32>
// CHECK: return
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
More information about the Mlir-commits
mailing list