[Mlir-commits] [mlir] [mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize (PR #93590)

Artem Kroviakov llvmlistbot at llvm.org
Tue May 28 10:55:52 PDT 2024


https://github.com/akroviakov created https://github.com/llvm/llvm-project/pull/93590

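As [suggested](https://github.com/llvm/llvm-project/pull/92370#discussion_r1617592942), the scalable-vector `assert`s in the linearization patterns are replaced with `notifyMatchFailure` calls. This makes them consistent with the other precondition checks in these patterns. Unsupported scalable vectors now cause the pattern to fail to match instead of aborting in debug builds.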

>From e83024b056378880e7c1a7a3b91d751b5b60737e Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 28 May 2024 10:51:22 -0700
Subject: [PATCH] [mlir][vector] Use notifyMatchFailure instead of assert in
 VectorLinearize

---
 .../Vector/Transforms/VectorLinearize.cpp     | 30 +++++++++++--------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf742f6297..840fd384894df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -265,10 +266,11 @@ struct LinearizeVectorShuffle final
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (shuffleOp.getV1VectorType().isScalable() ||
+        shuffleOp.getV2VectorType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -336,9 +338,10 @@ struct LinearizeVectorExtract final
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstTy).isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -395,9 +398,10 @@ struct LinearizeVectorInsert final
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
-    assert(!(insertOp.getDestVectorType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (insertOp.getDestVectorType().isScalable() ||
+        cast<VectorType>(dstTy).isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                          targetVectorBitWidth))



More information about the Mlir-commits mailing list