[Mlir-commits] [mlir] [mlir] Add a check to bail out when reducing to a scalar (PR #129694)

Prakhar Dixit llvmlistbot at llvm.org
Wed Mar 5 01:11:25 PST 2025


https://github.com/Prakhar-Dixit updated https://github.com/llvm/llvm-project/pull/129694

>From def5a8bac5e6b6bf1e62a109b15b2783fbc76f89 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Tue, 4 Mar 2025 17:44:35 +0530
Subject: [PATCH 1/3] [mlir][vector] Add a check to bail out of unrolling when
 reducing to a scalar, as ExtractStridedSliceOp does not support
 scalars

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 08ba972b12ce6..f519484fd56c8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto resultType = reductionOp->getResult(0).getType();
+    if (mlir::isa<mlir::FloatType>(resultType) ||
+        mlir::isa<mlir::IntegerType>(resultType)) {
+      return failure();
+    }
     std::optional<SmallVector<int64_t>> targetShape =
         getTargetShape(options, reductionOp);
     if (!targetShape)

>From 4dde55a4fddaf185c831b7be944604e3459f5df9 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 5 Mar 2025 10:28:26 +0530
Subject: [PATCH 2/3] Add a negative test and report the scalar bail-out via notifyMatchFailure

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++---
 mlir/test/Dialect/Vector/vector-unroll-options.mlir | 7 +++++++
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index f519484fd56c8..04c38f9f7b2e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -356,9 +356,9 @@ struct UnrollMultiReductionPattern
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
     auto resultType = reductionOp->getResult(0).getType();
-    if (mlir::isa<mlir::FloatType>(resultType) ||
-        mlir::isa<mlir::IntegerType>(resultType)) {
-      return failure();
+    if (resultType.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(reductionOp,
+                                         "Unrolling scalars is not supported");
     }
     std::optional<SmallVector<int64_t>> targetShape =
         getTargetShape(options, reductionOp);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 16d30aec7c041..db96e1b66a502 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -222,6 +222,13 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
 //       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
 
+func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
+  %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
+  return %0 : f32
+}
+// CHECK-LABEL: func @negative_vector_multi_reduction
+//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
+//       CHECK:   return %[[R0]] : f32
 
 func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
   %0 = vector.reduction <add>, %v : vector<8xf32> into f32

>From 0cf035c08caf658a20a09c864dbb7820a3a2b0dc Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Wed, 5 Mar 2025 14:41:07 +0530
Subject: [PATCH 3/3] Use CHECK-NEXT in the negative test and document its intent

---
 mlir/test/Dialect/Vector/vector-unroll-options.mlir | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index db96e1b66a502..9c158d05b723c 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -222,13 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
 //       CHECK:   %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 //       CHECK:   return %[[V2]] : vector<4xf32>
 
+// Negative test: a vector.multi_reduction that reduces to a scalar must not be
+// unrolled, because ExtractStridedSliceOp does not support scalar values.
 func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
   return %0 : f32
 }
 // CHECK-LABEL: func @negative_vector_multi_reduction
-//       CHECK:   %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
-//       CHECK:   return %[[R0]] : f32
+//  CHECK-NEXT:   %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
+//  CHECK-NEXT:   return %[[R0]] : f32
 
 func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
   %0 = vector.reduction <add>, %v : vector<8xf32> into f32



More information about the Mlir-commits mailing list