[Mlir-commits] [mlir] [mlir][Vector] Fix "scalability" in CastAwayExtractStridedSliceLeadingOneDim (PR #81187)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Feb 8 12:25:14 PST 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/81187

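Makes sure that the "scalability" (scalable-dimension) flags are correctly
updated when the `CastAwayExtractStridedSliceLeadingOneDim` pattern drops
leading unit dimensions from the result vector type, so that scalable
dimensions (e.g. `[8]`) are preserved in the new type.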


>From cb1f4782ee5c2118c56f01a396547309a4466f2d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 8 Feb 2024 20:18:23 +0000
Subject: [PATCH] [mlir][Vector] Fix "scalability" in
 CastAwayExtractStridedSliceLeadingOneDim

Makes sure that "scalability" flags in the
`CastAwayExtractStridedSliceLeadingOneDim` pattern are correctly
updated.
---
 .../Transforms/VectorDropLeadUnitDim.cpp      |  3 ++-
 .../vector-dropleadunitdim-transforms.mlir    | 21 +++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index e1ed5d81625d8e..74382b027c2f48 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -73,7 +73,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     VectorType oldDstType = extractOp.getType();
     VectorType newDstType =
         VectorType::get(oldDstType.getShape().drop_front(dropCount),
-                        oldDstType.getElementType());
+                        oldDstType.getElementType(),
+                        oldDstType.getScalableDims().drop_front(dropCount));
 
     Location loc = extractOp.getLoc();
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index f601be04168144..bb2d30f2092435 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -206,6 +206,16 @@ func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8x
   return %0: vector<1x1x8xf16>
 }
 
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
+func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
+  // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
+  // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x1x[8]xf16>
+}
+
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
 func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
   // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
@@ -217,6 +227,17 @@ func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16
   return %0: vector<1x8x8xf16>
 }
 
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
+func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
+  // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
+  // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
+  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
+  // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x8x[8]xf16>
+}
+
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {



More information about the Mlir-commits mailing list