[Mlir-commits] [mlir] [mlir][Vector] Fix an assertion on failing cast in vector-transfer-flatten-patterns (PR #86030)
Balaji V. Iyer.
llvmlistbot at llvm.org
Mon Mar 25 09:39:50 PDT 2024
https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/86030
>From afc2639af07bff07d887cfa38dd6eb872762d9cb Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Mar 2024 20:05:42 +0000
Subject: [PATCH 1/4] Added a dynamic check for VectorType.
When a result type is not a VectorType, `cast<VectorType>` fails
with an assertion. This patch uses `dyn_cast` instead and bails out
when any result is not a VectorType.
---
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7ca03537049812..38536de43f13f2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -22,9 +22,9 @@ using namespace mlir;
static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
auto resultTypes = op->getResultTypes();
for (auto resType : resultTypes) {
- VectorType vecType = cast<VectorType>(resType);
+ VectorType vecType = dyn_cast<VectorType>(resType);
// Reject index since getElementTypeBitWidth will abort for Index types.
- if (vecType.getElementType().isIndex())
+ if (!vecType || vecType.getElementType().isIndex())
return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
>From a99587bc6af5597ea09ea386ac4073b25fcab4e9 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 21 Mar 2024 21:36:33 +0000
Subject: [PATCH 2/4] Added a regression test for non-vector (tensor) result types
---
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed8..82b4115ea34e35 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -493,3 +493,14 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// This test exists to make sure it doesn't hit an assert and compiles through.
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0, %arg0 : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @simple_mul
+// CHECK-128B-LABEL: func.func @simple_mul
\ No newline at end of file
>From aa838792d77604b3fe916f93a6767b23dcbe35c8 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Fri, 22 Mar 2024 18:54:50 +0000
Subject: [PATCH 3/4] Moved the test into the correct file (linearize.mlir)
---
mlir/test/Dialect/Vector/linearize.mlir | 13 +++++++++++++
.../Dialect/Vector/vector-transfer-flatten.mlir | 10 ----------
2 files changed, 13 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2cbf9bec7a4136..aa7887a5dcfc06 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -90,3 +90,16 @@ func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
return %0 : vector<2x2xindex>
}
+
+// -----
+
+// This test exists to make sure it doesn't hit an assert and compiles through.
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0, %arg0 : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @simple_mul
+// CHECK128-LABEL: func.func @simple_mul
+// CHECK0-LABEL: func.func @simple_mul
+
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 82b4115ea34e35..23c59ed68d93f7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -494,13 +494,3 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
// CHECK-128B-NOT: memref.collapse_shape
-// -----
-
-// This test exists to make sure it doesn't hit an assert and compiles through.
-func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
- %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
- return %0, %arg0 : tensor<4xf32>, tensor<4xf32>
-}
-
-// CHECK-LABEL: func.func @simple_mul
-// CHECK-128B-LABEL: func.func @simple_mul
\ No newline at end of file
>From add5cebd7c5d43f9661d09f84e06119d402f6fd7 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Mon, 25 Mar 2024 16:39:30 +0000
Subject: [PATCH 4/4] Addressed Andrej's review: renamed the test to @nonvec_result and added CHECK lines
---
mlir/test/Dialect/Vector/linearize.mlir | 17 ++++++++++-------
.../Dialect/Vector/vector-transfer-flatten.mlir | 1 -
2 files changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index aa7887a5dcfc06..549f4d042fa259 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -93,13 +93,16 @@ func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi
// -----
-// This test exists to make sure it doesn't hit an assert and compiles through.
-func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+// CHECK-LABEL: func.func @nonvec_result
+// CHECK128-LABEL: func.func @nonvec_result
+// CHECK0-LABEL: func.func @nonvec_result
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>)
+// CHECK128-SAME: (%[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>)
+// CHECK0-SAME: (%[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>)
+func.func @nonvec_result(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+ // CHECK: %[[MULF:.*]] = arith.mulf %[[ARG0]], %[[ARG1]]
+ // CHECK128: %[[MULF:.*]] = arith.mulf %[[ARG0]], %[[ARG1]]
+ // CHECK0: %[[MULF:.*]] = arith.mulf %[[ARG0]], %[[ARG1]]
%0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
return %0, %arg0 : tensor<4xf32>, tensor<4xf32>
}
-
-// CHECK-LABEL: func.func @simple_mul
-// CHECK128-LABEL: func.func @simple_mul
-// CHECK0-LABEL: func.func @simple_mul
-
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 23c59ed68d93f7..788ae9ac044ed8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -493,4 +493,3 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
// CHECK-128B-NOT: memref.collapse_shape
-
More information about the Mlir-commits
mailing list