[Mlir-commits] [mlir] [mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize (PR #93590)

Artem Kroviakov llvmlistbot at llvm.org
Wed Jun 12 06:33:51 PDT 2024


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/93590

>From b1b0384e12550d7545d80d8257889a27831c264f Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 12 Jun 2024 06:33:14 -0700
Subject: [PATCH] [mlir][vector] Use notifyMatchFailure instead of assert in
 VectorLinearize

---
 .../Vector/Transforms/VectorLinearize.cpp     | 35 +++++++++++-------
 mlir/test/Dialect/Vector/linearize.mlir       | 36 +++++++++++++++++--
 2 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf742f6297..a1bb81e2d11b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final
   LogicalResult
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type dstType = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
-           "scalable vectors are not supported.");
+    VectorType dstType =
+        getTypeConverter()->convertType<VectorType>(extractOp.getType());
+    assert(dstType && "vector type destination expected.");
+    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
+    VectorType dstType =
+        getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
+    assert(dstType && "vector type destination expected.");
+    // vector.shuffle does not support scalable vectors, so this condition is
+    // an invariant; hence an assert rather than notifyMatchFailure.
     assert(!(shuffleOp.getV1VectorType().isScalable() ||
              shuffleOp.getV2VectorType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
+             dstType.isScalable()) &&
            "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
@@ -336,9 +342,10 @@ struct LinearizeVectorExtract final
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstTy).isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -394,10 +401,12 @@ struct LinearizeVectorInsert final
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
-    assert(!(insertOp.getDestVectorType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    VectorType dstTy = getTypeConverter()->convertType<VectorType>(
+        insertOp.getDestVectorType());
+    assert(dstTy && "vector type destination expected.");
+    if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                          targetVectorBitWidth))
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 31a59b809a74b..a3d271a0e1440 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -129,8 +129,8 @@ func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32
 // -----
 
 // ALL-LABEL:   func.func @test_scalable_no_linearize(
-// ALL-SAME:     %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
-func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
+// ALL-SAME:     %[[VAL_0:.*]]: vector<[2]x[2]xf32>,  %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
+func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
   // ALL: %[[CST:.*]] = arith.constant dense<2.000000e+00> : vector<[2]x[2]xf32>
   %0 = arith.constant dense<[[2., 2.], [2., 2.]]> : vector<[2]x[2]xf32>
 
@@ -177,6 +177,17 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
   return %0 : vector<2x2xf32>
 }
 
+// ALL-LABEL:   func.func @test_extract_strided_slice_scalable(
+// ALL-SAME:    %[[VAL_0:.*]]: vector<2x[2]xf32>) -> vector<1x[2]xf32> {
+func.func @test_extract_strided_slice_scalable(%arg0: vector<2x[2]xf32>) -> vector<1x[2]xf32> {
+  // ALL-NOT: vector.shuffle
+  // ALL-NOT: vector.shape_cast
+  // ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [1, 2], strides = [1, 1]} : vector<2x[2]xf32> to vector<1x[2]xf32>
+  %0 = vector.extract_strided_slice %arg0 { sizes = [1, 2], strides = [1, 1], offsets = [1, 0] } : vector<2x[2]xf32> to vector<1x[2]xf32>
+  // ALL: return %[[RES]] : vector<1x[2]xf32>
+  return %0 : vector<1x[2]xf32>
+}
+
 // -----
 // ALL-LABEL: test_extract_strided_slice_2
 // ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
@@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
   return %0 : vector<8x2xf32>
 }
 
+// ALL-LABEL:   func.func @test_vector_extract_scalable(
+// ALL-SAME:    %[[VAL_0:.*]]: vector<2x[2]xf32>) -> f32 {
+func.func @test_vector_extract_scalable(%arg1: vector<2x[2]xf32>) -> f32 {
+  // ALL-NOT: vector.shuffle
+  // ALL-NOT: vector.shape_cast
+  // ALL: %[[RES:.*]] = vector.extract %[[VAL_0]][0, 0] : f32 from vector<2x[2]xf32>
+  %0 = vector.extract %arg1[0, 0]: f32 from vector<2x[2]xf32>
+  // ALL: return %[[RES]] : f32
+  return %0 : f32
+}
 // -----
 // ALL-LABEL: test_vector_insert
 // ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -274,3 +295,14 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
   %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
   return %0 : vector<2x8x4xf32>
 }
+
+// ALL-LABEL:   func.func @test_vector_insert_scalable(
+// ALL-SAME:    %[[VAL_0:.*]]: vector<2x[2]xf32>, %[[VAL_1:.*]]: f32) -> vector<2x[2]xf32> {
+func.func @test_vector_insert_scalable(%arg0: vector<2x[2]xf32>, %arg1: f32) -> vector<2x[2]xf32> {
+  // ALL-NOT: vector.shuffle
+  // ALL-NOT: vector.shape_cast
+  // ALL: %[[RES:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0, 0] : f32 into vector<2x[2]xf32>
+  %0 = vector.insert %arg1, %arg0[0, 0]: f32 into vector<2x[2]xf32>
+  // ALL: return %[[RES]] : vector<2x[2]xf32>
+  return %0 : vector<2x[2]xf32>
+}



More information about the Mlir-commits mailing list