[Mlir-commits] [mlir] [mlir] alloc-to-alloca conversion for memref (PR #65335)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Sep 5 08:27:46 PDT 2023


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/65335:

Introduce a simple conversion of a memref.alloc/dealloc pair into an alloca in the same scope. Expose it as a transform op and a pattern.

Allocas typically lower to stack allocations, whereas alloc/dealloc lower to significantly more expensive malloc/free calls. In addition, this can be combined with hoisting allocations out of loops to further improve performance.

>From 8c2c5dd19dcef41907dbfcfd7cd32fdb6db89ebb Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 5 Sep 2023 15:04:52 +0000
Subject: [PATCH] [mlir] alloc-to-alloca conversion for memref

Introduce a simple conversion of a memref.alloc/dealloc pair into an
alloca in the same scope. Expose it as a transform op and a pattern.

Allocas typically lower to stack allocations as opposed to alloc/dealloc
that lower to significantly more expensive malloc/free calls. In
addition, this can be combined with allocation hoisting from loops to
further improve performance.
---
 .../MemRef/TransformOps/MemRefTransformOps.td | 17 +++++
 .../Dialect/MemRef/Transforms/Transforms.h    | 11 +++
 .../Transform/IR/TransformInterfaces.td       | 12 ++++
 .../TransformOps/MemRefTransformOps.cpp       | 37 ++++++++++
 .../Transforms/IndependenceTransforms.cpp     | 24 +++++++
 .../lib/Dialect/Transform/IR/TransformOps.cpp |  4 +-
 mlir/test/Dialect/MemRef/alloc-to-alloca.mlir | 68 +++++++++++++++++++
 7 files changed, 171 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/alloc-to-alloca.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 07c4777d32fd9b0..681759f970cb910 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -51,6 +51,23 @@ def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyAllocToAllocaOp : Op<Transform_Dialect,
+    "apply_patterns.memref.alloc_to_alloca",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface,
+                               ["populatePatternsWithState"]>]> {
+  let description = [{
+    Collects patterns to rewrite scoped dynamic allocation (`alloc`/`dealloc`
+    pairs) into automatic allocation (`alloca`) in the same scope, for memrefs
+    of static shape. An `alloc` is only rewritten when a `dealloc` of its
+    result appears later in the same block; that `dealloc` is then erased.
+
+    The `size_limit` attribute controls the maximum allocated memory (in bytes,
+    subject to data layout) for which the pattern applies. When the attribute
+    is absent or set to zero, statically shaped allocations of any size are
+    rewritten.
+  }];
+
+  let arguments = (ins
+    OptionalAttr<I64Attr>:$size_limit);
+  let assemblyFormat = "(`size_limit` `(` $size_limit^ `)`)? attr-dict";
+}
+
 def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
     "apply_patterns.memref.expand_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index d67d7f1f187ee9d..a918f62cbc8db8f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 
 namespace mlir {
 class OpBuilder;
@@ -31,6 +32,7 @@ class NarrowTypeEmulationConverter;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class DeallocOp;
 
 //===----------------------------------------------------------------------===//
 // Patterns
@@ -196,6 +198,15 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
                                           memref::AllocaOp allocaOp,
                                           ValueRange independencies);
 
+/// Replaces the given `alloc` with the corresponding `alloca`, erases the
+/// matching `dealloc` and returns the new `alloca` if the following conditions
+/// are met:
+///   - a dealloc of the alloc's result is found after the alloc in the same
+///     block (the first such dealloc accepted by the filter is the one used);
+///   - the filter, if provided, succeeds on the alloc/dealloc pair.
+/// Otherwise returns nullptr and leaves the IR unchanged.
+memref::AllocaOp allocToAlloca(
+    RewriterBase &rewriter, memref::AllocOp alloc,
+    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index d40d780e73c5543..40f863378ed680a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -260,6 +260,18 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
       /*name=*/"populatePatterns",
       /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns)
     >,
+    // Optional, state-aware variant of `populatePatterns` for descriptor ops
+    // whose patterns need information from the transform state (for example,
+    // its top-level payload op). The default implementation ignores the state
+    // and simply forwards to `populatePatterns`.
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate rewrite patterns into the given pattern set taking into account
+        the transform state.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populatePatternsWithState",
+      /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns,
+                         "::mlir::transform::TransformState &":$state),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ $_op.populatePatterns(patterns); }]
+    >
   ];
 }
 
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 76ab6214dea535c..58f4d8d8f6d21fe 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -64,6 +65,42 @@ StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// Rewrites a `memref.alloc` into a `memref.alloca` when the allocated memref
+/// has a static shape, a matching `memref.dealloc` follows it in the same
+/// block, and the allocation size in bytes (number of elements times the
+/// element size from the data layout found at or above the alloc) is at most
+/// `maxSize`. A `maxSize` of zero disables the size check.
+class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
+public:
+  /// Data layouts are looked up through an analysis rooted at `analysisRoot`,
+  /// which also provides the MLIR context of the pattern.
+  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
+      : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
+        dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
+
+  LogicalResult matchAndRewrite(memref::AllocOp op,
+                                PatternRewriter &rewriter) const override {
+    return success(memref::allocToAlloca(
+        rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp) {
+          MemRefType type = alloc.getMemref().getType();
+          if (!type.hasStaticShape())
+            return false;
+          // No limit requested: no need to consult the data layout.
+          if (maxSize == 0)
+            return true;
+
+          const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
+          int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
+          // `size_limit` is documented as the *maximum* allocated memory, so
+          // an allocation of exactly `maxSize` bytes is still accepted.
+          return type.getNumElements() * elementSize <= maxSize;
+        }));
+  }
+
+private:
+  /// Provides the data layout in effect at each candidate alloc.
+  DataLayoutAnalysis dataLayoutAnalysis;
+  /// Inclusive size limit in bytes; zero means no limit.
+  int64_t maxSize;
+};
+} // namespace
+
+// Intentionally empty: the alloc-to-alloca pattern roots its data layout
+// analysis at the transform state's top-level op, so it is only added by the
+// state-aware hook `populatePatternsWithState`.
+void transform::ApplyAllocToAllocaOp::populatePatterns(
+    RewritePatternSet &patterns) {}
+
+// Adds the alloc-to-alloca pattern; a missing `size_limit` maps to 0, which
+// the pattern treats as "no size limit".
+void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
+    RewritePatternSet &patterns, transform::TransformState &state) {
+  Operation *analysisRoot = state.getTopLevel();
+  auto sizeLimitInBytes = static_cast<int64_t>(getSizeLimit().value_or(0));
+  patterns.insert<AllocToAllocaPattern>(analysisRoot, sizeLimitInBytes);
+}
+
 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   memref::populateExpandOpsPatterns(patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 5b789da37aa7a22..03765e95b01e7a2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -178,3 +178,27 @@ FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
                                 replacement->getDefiningOp());
   return replacement;
 }
+
+memref::AllocaOp memref::allocToAlloca(
+    RewriterBase &rewriter, memref::AllocOp alloc,
+    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
+  // Look for the first dealloc of `alloc` that follows it in the same block
+  // and is accepted by the filter. The result is kept separate from the loop
+  // cursor so that an unrelated dealloc that happens to be the last operation
+  // of the block is never mistaken for a match.
+  memref::DeallocOp dealloc = nullptr;
+  for (Operation &op :
+       llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
+    auto candidate = dyn_cast<memref::DeallocOp>(op);
+    if (candidate && candidate.getMemref() == alloc.getMemref() &&
+        (!filter || filter(alloc, candidate))) {
+      dealloc = candidate;
+      break;
+    }
+  }
+
+  if (!dealloc)
+    return nullptr;
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(alloc);
+  // Forward dynamic sizes and symbol operands through their own builder
+  // parameters instead of lumping every operand into the dynamic sizes, and
+  // keep the alignment requested on the original alloc.
+  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
+      alloc, alloc.getMemref().getType(), alloc.getDynamicSizes(),
+      alloc.getSymbolOperands(), alloc.getAlignmentAttr());
+  rewriter.eraseOp(dealloc);
+  return alloca;
+}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3c563c0a36c3bb0..7bbbbba4134b184 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -378,8 +378,8 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
   RewritePatternSet patterns(ctx);
   if (!getRegion().empty()) {
     for (Operation &op : getRegion().front()) {
-      cast<transform::PatternDescriptorOpInterface>(&op).populatePatterns(
-          patterns);
+      cast<transform::PatternDescriptorOpInterface>(&op)
+          .populatePatternsWithState(patterns, state);
     }
   }
 
diff --git a/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir b/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir
new file mode 100644
index 000000000000000..2e788236ca91152
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=all %s | FileCheck %s --check-prefixes=CHECK,ALL
+// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=small %s | FileCheck %s --check-prefixes=CHECK,SMALL
+
+// The "all" transform sequence at the bottom applies the pattern without a
+// size limit; the "small" sequence restricts it with `size_limit(32)`.
+
+func.func private @callee(memref<*xf32>)
+
+// A 100x100xf32 buffer is far above the 32-byte limit, so only the unlimited
+// run turns it into an alloca and drops the dealloc.
+// CHECK-LABEL: @large_alloc
+func.func @large_alloc() {
+  // SMALL: memref.alloc()
+  // ALL:   memref.alloca
+  %0 = memref.alloc() : memref<100x100xf32>
+  %1 = memref.cast %0 : memref<100x100xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // SMALL: memref.dealloc
+  // ALL-NOT: memref.dealloc
+  memref.dealloc %0 : memref<100x100xf32>
+  return
+}
+
+// A 2x2xf32 buffer fits within the limit and is rewritten in both runs.
+// CHECK-LABEL: @small_alloc
+func.func @small_alloc() {
+  // CHECK: memref.alloca
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // CHECK-NOT: memref.dealloc
+  memref.dealloc %0 : memref<2x2xf32>
+  return
+}
+
+// Without a dealloc in the same block, the allocation is left untouched.
+// CHECK-LABEL: @no_dealloc
+func.func @no_dealloc() {
+  // CHECK: memref.alloc()
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  return
+}
+
+// The dealloc lives in a nested region rather than in the alloc's block, so
+// the pair is not rewritten.
+// CHECK-LABEL: @mismatching_scope
+func.func @mismatching_scope() {
+  // CHECK: memref.alloc()
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  scf.execute_region {
+    memref.dealloc %0 : memref<2x2xf32>
+    scf.yield
+  }
+  return
+}
+
+// Only statically shaped memrefs are rewritten, even when a matching dealloc
+// follows the alloc in the same block.
+// CHECK-LABEL: @dynamic_shape
+func.func @dynamic_shape(%arg0: index) {
+  // CHECK: memref.alloc(
+  %0 = memref.alloc(%arg0) : memref<?xf32>
+  %1 = memref.cast %0 : memref<?xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // CHECK: memref.dealloc
+  memref.dealloc %0 : memref<?xf32>
+  return
+}
+
+transform.sequence failures(propagate) attributes {transform.target_tag = "all"} {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.memref.alloc_to_alloca
+  } : !transform.any_op
+  transform.yield
+}
+
+transform.sequence failures(propagate) attributes {transform.target_tag = "small"} {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.memref.alloc_to_alloca size_limit(32)
+  } : !transform.any_op
+  transform.yield
+}



More information about the Mlir-commits mailing list