[Mlir-commits] [mlir] 3a33c14 - [mlir][MemRef] Add an extract_strided_metadata(extract_strided_metadata) pattern

Quentin Colombet llvmlistbot at llvm.org
Fri Oct 14 12:03:18 PDT 2022


Author: Quentin Colombet
Date: 2022-10-14T19:02:10Z
New Revision: 3a33c146edd2a78b2160456e83918a4a042dcc62

URL: https://github.com/llvm/llvm-project/commit/3a33c146edd2a78b2160456e83918a4a042dcc62
DIFF: https://github.com/llvm/llvm-project/commit/3a33c146edd2a78b2160456e83918a4a042dcc62.diff

LOG: [mlir][MemRef] Add an extract_strided_metadata(extract_strided_metadata) pattern

This pattern will be useful to get cleaner code when lowering view-like
operations.

Differential Revision: https://reviews.llvm.org/D135836

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 2a8ffba9c32c4..6aa68ae249635 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -718,6 +718,34 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
     return success();
   }
 };
+
+/// Fold extract_strided_metadata applied to the base buffer (result #0) of
+/// another extract_strided_metadata. Replaces `base, offset =
+///   extract_strided_metadata(extract_strided_metadata(src)#0)` with
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = 0
+/// ```
+class ExtractStridedMetadataOpExtractStridedMetadataFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    auto sourceExtractStridedMetadataOp =
+        extractStridedMetadataOp.getSource()
+            .getDefiningOp<memref::ExtractStridedMetadataOp>();
+    if (!sourceExtractStridedMetadataOp) // Source is not a nested query.
+      return failure();
+    Location loc = extractStridedMetadataOp.getLoc();
+    rewriter.replaceOp(extractStridedMetadataOp,
+                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
+                        getValueOrCreateConstantIndexOp( // New offset: 0.
+                            rewriter, loc, rewriter.getIndexAttr(0))});
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
@@ -731,7 +759,8 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns(
            ForwardStaticMetadata,
            ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
            ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
-           RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
+           RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+           ExtractStridedMetadataOpExtractStridedMetadataFolder>(
           patterns.getContext());
 }
 

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 616b835842910..338b52b5ac427 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -841,3 +841,29 @@ func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1
   return %base, %offset :
       memref<i32>, index
 }
+
+// -----
+
+// Check that extract_strided_metadata applied to the base buffer of another
+// extract_strided_metadata folds to the inner base and a constant 0 offset.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_extract_strided_metadata(
+//  CHECK-SAME: %[[ARG:.*]]: memref<i32>)
+//
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C0]]
+func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i32>)
+  -> (memref<i32>, index) {
+
+  %base, %offset =
+    memref.extract_strided_metadata %arg:memref<i32>
+    -> memref<i32>, index
+  %base2, %offset2 =
+    memref.extract_strided_metadata %base:memref<i32>
+    -> memref<i32>, index
+
+  return %base2, %offset2 :
+      memref<i32>, index
+}


        


More information about the Mlir-commits mailing list