[Mlir-commits] [mlir] [mlir][Vector] Added a dynamic check for VectorType. (PR #86030)

Balaji V. Iyer. llvmlistbot at llvm.org
Thu Mar 21 14:36:53 PDT 2024


https://github.com/bviyer updated https://github.com/llvm/llvm-project/pull/86030

>From afc2639af07bff07d887cfa38dd6eb872762d9cb Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Wed, 20 Mar 2024 20:05:42 +0000
Subject: [PATCH 1/2] Added a dynamic check for VectorType.

When an op result is not a VectorType, isLessThanTargetBitWidth asserts
in cast<VectorType>. This patch uses dyn_cast instead and bails out
(returns false) when a result is not a VectorType.
---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7ca03537049812..38536de43f13f2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -22,9 +22,9 @@ using namespace mlir;
 static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
   auto resultTypes = op->getResultTypes();
   for (auto resType : resultTypes) {
-    VectorType vecType = cast<VectorType>(resType);
+    VectorType vecType = dyn_cast<VectorType>(resType);
     // Reject index since getElementTypeBitWidth will abort for Index types.
-    if (vecType.getElementType().isIndex())
+    if (!vecType || vecType.getElementType().isIndex())
       return false;
     unsigned trailingVecDimBitWidth =
         vecType.getShape().back() * vecType.getElementTypeBitWidth();

>From a99587bc6af5597ea09ea386ac4073b25fcab4e9 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer" <bviyer at gmail.com>
Date: Thu, 21 Mar 2024 21:36:33 +0000
Subject: [PATCH 2/2] Added a test

---
 mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 788ae9ac044ed8..82b4115ea34e35 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -493,3 +493,14 @@ func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
 
 // CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
 //   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// Regression test: ops with non-vector (tensor) results must not assert.
+func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+    %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+    return %0, %arg0 : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @simple_mul
+// CHECK-128B-LABEL: func.func @simple_mul
\ No newline at end of file



More information about the Mlir-commits mailing list