[Mlir-commits] [mlir] d905b1c - [MLIR] Vector dialect: Address post-merge review comments on #111541 (#111552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 8 12:35:18 PDT 2024


Author: Benoit Jacob
Date: 2024-10-08T15:35:16-04:00
New Revision: d905b1caf14d51ebdc67a3c114a2265d479f818c

URL: https://github.com/llvm/llvm-project/commit/d905b1caf14d51ebdc67a3c114a2265d479f818c
DIFF: https://github.com/llvm/llvm-project/commit/d905b1caf14d51ebdc67a3c114a2265d479f818c.diff

LOG: [MLIR] Vector dialect: Address post-merge review comments on #111541 (#111552)

Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index c2da9347aadc87..ad845608f18d10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -331,6 +331,17 @@ class DecomposeNDExtractStridedSlice
 
 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
 /// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+///     Before:
+///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+///     After:
+///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
+///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
+///         vector<1x1x1x1x8xi8>
+///
 class ContiguousExtractStridedSliceToExtract final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:

diff  --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
index 9147e7bf02581e..d1401ad7853fc9 100644
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -1,35 +1,24 @@
 // RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
 
-// CHECK-LABEL: @extract_strided_slice_to_extract_i8
-// CHECK:       %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
-// CHECK:       return %[[EXTRACT]] :  vector<8xi8>
-func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
-  %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
-  return %2 : vector<8xi8>
-}
-
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32
+// CHECK-LABEL: @contiguous
 // CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK:       return %[[EXTRACT]] :  vector<4xi32>
-func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
+func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
   %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
   return %2 : vector<4xi32>
 }
 
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
-// CHECK:        vector.extract_strided_slice
-func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+// CHECK-LABEL: @non_full_size
+// CHECK-NEXT:   vector.extract_strided_slice
+func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
-  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
-  return %2 : vector<2xi32>
+  return %1 : vector<1x1x1x1x1x2xi32>
 }
 
-// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
-// CHECK:        vector.extract_strided_slice
-func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+// CHECK-LABEL: @non_full_inner_size
+// CHECK-NEXT:    vector.extract_strided_slice
+func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
-  %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
-  return %2 : vector<2xi32>
+  return %1 : vector<1x1x2x1x1x1xi32>
 }


        


More information about the Mlir-commits mailing list