[Mlir-commits] [mlir] f7e1ce0 - [mlir][MemRef] Add pattern that forwards constant strided metadata.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Sep 26 08:34:40 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-26T08:34:31-07:00
New Revision: f7e1ce0f3071849fe9c7932a92038c51105fd8bc

URL: https://github.com/llvm/llvm-project/commit/f7e1ce0f3071849fe9c7932a92038c51105fd8bc
DIFF: https://github.com/llvm/llvm-project/commit/f7e1ce0f3071849fe9c7932a92038c51105fd8bc.diff

LOG: [mlir][MemRef] Add pattern that forwards constant strided metadata.

`memref.extract_strided_metadata` can forward its static offset, sizes, and
strides as constants, independently of the existence of other operations such
as subview or reshape.

Differential Revision: https://reviews.llvm.org/D134603

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 16d17aa0b183b..dcadd5b33d078 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
@@ -428,13 +429,72 @@ struct ExtractStridedMetadataOpExpandShapeFolder
     return success();
   }
 };
+
+/// Replaces every use of each value in `values` with an index constant taken
+/// from the matching entry of `maybeConstants`, skipping entries `isDynamic`
+/// flags. Both must have the same size. Returns true if any use was replaced.
+template <typename Container>
+bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc,
+                           Container values, ArrayRef<int64_t> maybeConstants,
+                           llvm::function_ref<bool(int64_t)> isDynamic) {
+  assert(values.size() == maybeConstants.size() &&
+         " expected values and maybeConstants of the same size");
+  bool atLeastOneReplacement = false;
+  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
+    // Skip dynamic entries. Also don't materialize a constant for an unused
+    // result: this would induce infinite loops in the driver.
+    if (isDynamic(maybeConstant) || result.use_empty())
+      continue;
+    Value constantVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
+      rewriter.startRootUpdate(op);
+      // Not using updateRootInPlace: in C++17 a lambda cannot capture the
+      // structured binding `result`.
+      op->replaceUsesOfWith(result, constantVal);
+      rewriter.finalizeRootUpdate(op);
+      atLeastOneReplacement = true;
+    }
+  }
+  return atLeastOneReplacement;
+}
+
+/// Replaces uses of statically known offset/sizes/strides with constants.
+struct ForwardStaticMetadata
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
+                                PatternRewriter &rewriter) const override {
+    auto memrefType = metadataOp.getSource().getType().cast<MemRefType>();
+    SmallVector<int64_t> strides;
+    int64_t offset; // Set by getStridesAndOffset below.
+    LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+    (void)res; // Keep `res` used when NDEBUG compiles the assert out.
+    assert(succeeded(res) && "must be a strided memref type");
+
+    bool atLeastOneReplacement = replaceConstantUsesOf(
+        rewriter, metadataOp.getLoc(),
+        ArrayRef<TypedValue<IndexType>>(metadataOp.getOffset()),
+        ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
+    atLeastOneReplacement |= replaceConstantUsesOf(
+        rewriter, metadataOp.getLoc(), metadataOp.getSizes(),
+        memrefType.getShape(), ShapedType::isDynamic);
+    atLeastOneReplacement |= replaceConstantUsesOf(
+        rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides,
+        ShapedType::isDynamicStrideOrOffset);
+
+    return success(atLeastOneReplacement); // failure() if nothing changed.
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpExpandShapeFolder>(
-      patterns.getContext());
+  // The folders, plus ForwardStaticMetadata (static metadata -> constants).
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder,
+               ForwardStaticMetadata>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 3b2b00d2dc6fa..dd5889fe47fd9 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -1,5 +1,24 @@
 // RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s
 
+// CHECK-LABEL: func @extract_strided_metadata_constants
+//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+    -> (memref<f32>, index, index, index, index, index) {
+  //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  //   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  //   CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+  // Every static offset, size and stride is forwarded as an index constant.
+  //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
+    memref<5x4xf32, strided<[4,1], offset:2>>
+    -> memref<f32>, index, index, index, index, index
+
+  // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
 // -----
 
 // Check that we simplify extract_strided_metadata of subview to


        


More information about the Mlir-commits mailing list