[Mlir-commits] [mlir] [MLIR] Vector dialect: Address post-merge review comments on #111541 (PR #111552)
Benoit Jacob
llvmlistbot at llvm.org
Tue Oct 8 11:39:58 PDT 2024
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/111552
>From 1f2ef0bccb98036ad00c0c176cb49ae7d79876c4 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 11:10:00 -0500
Subject: [PATCH 1/4] post-review-comments-111541
---
...uous-extract-strided-slice-to-extract.mlir | 21 ++++++-------------
1 file changed, 6 insertions(+), 15 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
index 9147e7bf02581e..ea1bdedaf76286 100644
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -1,34 +1,25 @@
// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
-// CHECK-LABEL: @extract_strided_slice_to_extract_i8
-// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
-// CHECK: return %[[EXTRACT]] : vector<8xi8>
-func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
- %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
- %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
- return %2 : vector<8xi8>
-}
-
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32
+// CHECK-LABEL: @contiguous
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
// CHECK: return %[[EXTRACT]] : vector<4xi32>
-func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
return %2 : vector<4xi32>
}
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
+// CHECK-LABEL: @non_full_size
// CHECK: vector.extract_strided_slice
-func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
return %2 : vector<2xi32>
}
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
+// CHECK-LABEL: @non_contiguous
// CHECK: vector.extract_strided_slice
-func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+func.func @non_contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
%2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
return %2 : vector<2xi32>
>From f0b0a7856b45011c83aab9b35116e14fc5a61bda Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 14:05:12 -0400
Subject: [PATCH 2/4] Update
mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
---
.../vector-contiguous-extract-strided-slice-to-extract.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
index ea1bdedaf76286..3c35843e724575 100644
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -19,7 +19,7 @@ func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
// CHECK-LABEL: @non_contiguous
// CHECK: vector.extract_strided_slice
-func.func @non_contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
%2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
return %2 : vector<2xi32>
>From 0c84f441a39cb694c8825d1405ad36c3d61da497 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 13:23:57 -0500
Subject: [PATCH 3/4] review comments
---
...nsertExtractStridedSliceRewritePatterns.cpp | 8 ++++++++
...guous-extract-strided-slice-to-extract.mlir | 18 ++++++++----------
2 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index c2da9347aadc87..2c939362726fb9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -331,6 +331,14 @@ class DecomposeNDExtractStridedSlice
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+/// Before:
+/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+/// After:
+/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+/// %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>
+///
class ContiguousExtractStridedSliceToExtract final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
index 3c35843e724575..d1401ad7853fc9 100644
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: @contiguous
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK: return %[[EXTRACT]] : vector<4xi32>
+// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
@@ -10,17 +10,15 @@ func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
}
// CHECK-LABEL: @non_full_size
-// CHECK: vector.extract_strided_slice
-func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
- %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
- return %2 : vector<2xi32>
+ return %1 : vector<1x1x1x1x1x2xi32>
}
-// CHECK-LABEL: @non_contiguous
-// CHECK: vector.extract_strided_slice
-func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+// CHECK-LABEL: @non_full_inner_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
- %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
- return %2 : vector<2xi32>
+ return %1 : vector<1x1x2x1x1x1xi32>
}
>From df97d424f1a0f74481c6fbd78477b69c12d5b52b Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 13:39:42 -0500
Subject: [PATCH 4/4] clang-format
---
.../VectorInsertExtractStridedSliceRewritePatterns.cpp | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 2c939362726fb9..ad845608f18d10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -334,10 +334,13 @@ class DecomposeNDExtractStridedSlice
///
/// Example:
/// Before:
-/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
/// After:
-/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
-/// %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>
/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
///     vector<8x1x1x2x8xi8>
/// %1 = vector.shape_cast %0 : vector<8xi8> to vector<1x1x1x1x8xi8>
///
class ContiguousExtractStridedSliceToExtract final
: public OpRewritePattern<ExtractStridedSliceOp> {
More information about the Mlir-commits
mailing list