[Mlir-commits] [mlir] [mlir][vector] Add a check to ensure input vector rank equals target shape rank (PR #127706)

Prakhar Dixit llvmlistbot at llvm.org
Fri Feb 21 02:12:18 PST 2025


https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/127706

>From b5715c0ffbcc16cd22e26d5975c332381846b629 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 19 Feb 2025 03:35:46 +0530
Subject: [PATCH 1/4] Add check to ensure input vector rank equals target shape
 rank

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index c1e3850f05c5e..82e473ef7e3b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,8 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    if (originalSize.size() != targetShape->size())
+      return failure();
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(

>From 8717cbac20bdc2479be931370a9078cd9dfa13ed Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Feb 2025 00:08:16 +0530
Subject: [PATCH 2/4] [vector][mlir] Add required comments and test in
 VectorUnroll.cpp and invalid.mlir respectively

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp |  3 +++
 mlir/test/Dialect/Vector/invalid.mlir               | 10 ++++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 82e473ef7e3b0..42f3fab95d975 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -437,6 +437,9 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    // Bail-out if rank(source) != rank(target). The main limitation here is the
+    // fact that `ExtractStridedSlice` requires the rank for the input and
+    // output to match. If needed, we can relax this later.
     if (originalSize.size() != targetShape->size())
       return failure();
     Location loc = op->getLoc();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..d5e47a39e5107 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -766,6 +766,16 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
 }
 
+// -----
+
+ func.func @extract_strided_slice() -> () {
+  // expected-error at +1 {{expected input vector rank to match target shape rank}}
+  %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+         vector<24x2x2xf32> to vector<2x2xf32>
+  return
+}
+
 // -----
 
 #contraction_accesses = [

>From de2edb06b14b42e5212028ad4ac608340a730f1c Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 13:37:12 +0530
Subject: [PATCH 3/4] Add a negative test and also update the diagnostic test
 with fewer ops

---
 mlir/test/Dialect/Vector/invalid.mlir              |  7 +++----
 .../test/Dialect/Vector/vector-unroll-options.mlir | 14 ++++++++++++++
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d5e47a39e5107..7c1892b8b5f53 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -768,11 +768,10 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
- func.func @extract_strided_slice() -> () {
+func.func @extract_strided_slice(%arg0: vector<3x2x2xf32>) {
   // expected-error at +1 {{expected input vector rank to match target shape rank}}
-  %0 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
-  %1 = vector.extract_strided_slice %0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
-         vector<24x2x2xf32> to vector<2x2xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}:
+         vector<3x2x2xf32> to vector<2x2xf32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 7e3fe56f6b124..d1dd65727e204 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,6 +188,20 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
+func.func @higher_rank_unroll() {
+  %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
+  %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
+  %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
+  %818 = scf.execute_region -> vector<24x2x2xf32> {
+      scf.yield %47 : vector<24x2x2xf32>
+    }
+  %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
+  return
+}
+
+// CHECK-LABEL: func @higher_rank_unroll
+//       CHECK: return 
+
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
   return %0 : vector<4xf32>

>From 88063306e7c7f9d16c21366f9664821c6db11253 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Fri, 21 Feb 2025 15:40:18 +0530
Subject: [PATCH 4/4] Simplify the negative unroll test to a single 3-D vector.fma

---
 .../Dialect/Vector/vector-unroll-options.mlir   | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index d1dd65727e204..9485298aa2d3e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,18 +188,13 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-func.func @higher_rank_unroll() {
-  %cst_25 = arith.constant dense<3.718400e+04> : vector<4x2x2xf16>
-  %cst_26 = arith.constant dense<1.000000e+00> : vector<24x2x2xf32>
-  %47 = vector.fma %cst_26, %cst_26, %cst_26 : vector<24x2x2xf32>
-  %818 = scf.execute_region -> vector<24x2x2xf32> {
-      scf.yield %47 : vector<24x2x2xf32>
-    }
-  %823 = vector.extract_strided_slice %cst_25 {offsets = [2], sizes = [1], strides = [1]} : vector<4x2x2xf16> to vector<1x2x2xf16>
-  return
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
+  %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
+  return %0 : vector<3x2x2xf32>
 }
-
-// CHECK-LABEL: func @higher_rank_unroll
+// CHECK-LABEL: func @negative_vector_fma_3d
+//   CHECK-NOT: vector.extract_strided_slice
+//       CHECK: %[[R:.*]] = vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<3x2x2xf32>
 //       CHECK: return 
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {



More information about the Mlir-commits mailing list