[Mlir-commits] [mlir] e55e36d - [mlir] alloc-to-alloca conversion for memref (#65335)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 5 08:58:26 PDT 2023
Author: Oleksandr "Alex" Zinenko
Date: 2023-09-05T17:58:22+02:00
New Revision: e55e36de7a10e731f76c29ef5dd5287e323eb98c
URL: https://github.com/llvm/llvm-project/commit/e55e36de7a10e731f76c29ef5dd5287e323eb98c
DIFF: https://github.com/llvm/llvm-project/commit/e55e36de7a10e731f76c29ef5dd5287e323eb98c.diff
LOG: [mlir] alloc-to-alloca conversion for memref (#65335)
Introduce a simple conversion of a memref.alloc/dealloc pair into an
alloca in the same scope. Expose it as a transform op and a pattern.
Allocas typically lower to stack allocations as opposed to alloc/dealloc
that lower to significantly more expensive malloc/free calls. In
addition, this can be combined with allocation hoisting from loops to
further improve performance.
Added:
mlir/test/Dialect/MemRef/alloc-to-alloca.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 07c4777d32fd9b0..681759f970cb910 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -51,6 +51,23 @@ def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyAllocToAllocaOp : Op<Transform_Dialect,
+    "apply_patterns.memref.alloc_to_alloca",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface,
+                               ["populatePatternsWithState"]>]> {
+  let description = [{
+    Collects patterns to rewrite scoped dynamic allocation (`alloc`/`dealloc`
+    pairs) into automatic allocation (`alloca`) in the same scope, for memrefs
+    of static shape. An `alloc` is only rewritten when a `dealloc` of its
+    result follows it in the same block; that `dealloc` is then removed.
+
+    The optional `size_limit` attribute bounds the allocated memory, in bytes,
+    computed as the number of elements times the element type size given by
+    the data layout in effect at the `alloc`. The pattern applies only to
+    allocations whose total size is strictly smaller than `size_limit`. When
+    `size_limit` is absent or 0, no size limit is enforced.
+  }];
+
+  let arguments = (ins
+    OptionalAttr<I64Attr>:$size_limit);
+  let assemblyFormat = "(`size_limit` `(` $size_limit^ `)`)? attr-dict";
+}
+
def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
"apply_patterns.memref.expand_ops",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index d67d7f1f187ee9d..a918f62cbc8db8f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -15,6 +15,7 @@
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
namespace mlir {
class OpBuilder;
@@ -31,6 +32,7 @@ class NarrowTypeEmulationConverter;
namespace memref {
class AllocOp;
class AllocaOp;
+class DeallocOp;
//===----------------------------------------------------------------------===//
// Patterns
@@ -196,6 +198,15 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocaOp,
ValueRange independencies);
+/// Attempts to fold the `alloc`/`dealloc` pair of `alloc` into a single
+/// `memref.alloca`. Only operations from `alloc` to the end of its block are
+/// searched for a `memref.dealloc` of the allocated memref. If such a dealloc
+/// is found and `filter` (when non-null) accepts the (alloc, dealloc) pair:
+///   - `alloc` is replaced by an `alloca` of the same memref type;
+///   - the dealloc is erased;
+///   - the new `alloca` is returned.
+/// Otherwise a null op is returned and the IR is not modified.
+///
+/// NOTE(review): per the documented `memref.alloca` semantics, the buffer is
+/// released when the nearest enclosing automatic allocation scope exits, not
+/// at the former dealloc point. Rewriting allocations inside loop bodies may
+/// therefore increase stack usage.
+memref::AllocaOp
+allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc,
+              function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter =
+                  nullptr);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index d40d780e73c5543..40f863378ed680a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -260,6 +260,18 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
/*name=*/"populatePatterns",
/*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns)
>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Populate rewrite patterns into the given pattern set taking into account
+        the transform state. The default implementation ignores the state and
+        forwards to `populatePatterns`. Ops whose patterns depend on the
+        transform state (for example, on the top-level payload operation)
+        should override this method rather than `populatePatterns`.
+      }],
+      /*returnType=*/"void",
+      /*name=*/"populatePatternsWithState",
+      /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns,
+                         "::mlir::transform::TransformState &":$state),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ $_op.populatePatterns(patterns); }]
+    >
];
}
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 76ab6214dea535c..58f4d8d8f6d21fe 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -64,6 +65,42 @@ StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
+namespace {
+/// Rewrites a `memref.alloc` into a `memref.alloca` through
+/// `memref::allocToAlloca`. It only accepts allocations of static shape and,
+/// when `maxSize` is non-zero, allocations occupying strictly fewer than
+/// `maxSize` bytes. The size is the element count times the element type size
+/// reported by the data layout found at or above the alloc.
+class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
+public:
+  /// `analysisRoot` is the operation the data layout analysis is built on;
+  /// a `maxSize` of 0 disables the size check.
+  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
+      : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
+        dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
+
+  LogicalResult matchAndRewrite(memref::AllocOp op,
+                                PatternRewriter &rewriter) const override {
+    // Acceptance criterion for a candidate pair; only the alloc side matters.
+    auto isAcceptableAlloc = [this](memref::AllocOp candidate,
+                                    memref::DeallocOp /*dealloc*/) -> bool {
+      MemRefType memrefType = candidate.getMemref().getType();
+      if (!memrefType.hasStaticShape())
+        return false;
+      if (maxSize == 0)
+        return true;
+      const DataLayout &layout = dataLayoutAnalysis.getAtOrAbove(candidate);
+      int64_t bytesPerElement = layout.getTypeSize(memrefType.getElementType());
+      int64_t totalBytes = memrefType.getNumElements() * bytesPerElement;
+      return totalBytes < maxSize;
+    };
+
+    memref::AllocaOp replacement =
+        memref::allocToAlloca(rewriter, op, isAcceptableAlloc);
+    return success(static_cast<bool>(replacement));
+  }
+
+private:
+  /// Data layouts queried for element sizes; built once at construction.
+  DataLayoutAnalysis dataLayoutAnalysis;
+  /// Exclusive upper bound on the allocation size in bytes; 0 means unbounded.
+  int64_t maxSize;
+};
+} // namespace
+
+// Intentionally empty: the alloc-to-alloca pattern needs the transform state
+// (its top-level op roots the data layout analysis). It is therefore added in
+// populatePatternsWithState only.
+void transform::ApplyAllocToAllocaOp::populatePatterns(
+    RewritePatternSet &patterns) {}
+
+void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
+    RewritePatternSet &patterns, transform::TransformState &state) {
+  // A missing `size_limit` maps to 0, which the pattern treats as "no limit".
+  int64_t sizeLimit = static_cast<int64_t>(getSizeLimit().value_or(0));
+  Operation *analysisRoot = state.getTopLevel();
+  patterns.insert<AllocToAllocaPattern>(analysisRoot, sizeLimit);
+}
+
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExpandOpsPatterns(patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 5b789da37aa7a22..03765e95b01e7a2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -178,3 +178,27 @@ FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
replacement->getDefiningOp());
return replacement;
}
+
+/// Replaces `alloc` with an equivalent `memref.alloca` and erases the matching
+/// `memref.dealloc`, provided that:
+///   - a dealloc of the allocated memref follows `alloc` in the same block;
+///   - `filter`, when non-null, accepts the (alloc, dealloc) pair.
+/// Returns the new alloca on success. Otherwise returns a null op and leaves
+/// the IR untouched.
+memref::AllocaOp memref::allocToAlloca(
+    RewriterBase &rewriter, memref::AllocOp alloc,
+    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
+  // Locate the first dealloc of the allocated value after `alloc` in its
+  // block. A later dealloc of the same value would be a double free, so only
+  // the first one is a meaningful candidate. `dealloc` is assigned only on an
+  // actual match, so that a non-matching dealloc that happens to be the last
+  // op visited can never be mistaken for the pair and erased.
+  memref::DeallocOp dealloc = nullptr;
+  for (Operation &candidate :
+       llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
+    auto candidateDealloc = dyn_cast<memref::DeallocOp>(candidate);
+    if (candidateDealloc &&
+        candidateDealloc.getMemref() == alloc.getMemref()) {
+      dealloc = candidateDealloc;
+      break;
+    }
+  }
+
+  if (!dealloc || (filter && !filter(alloc, dealloc)))
+    return nullptr;
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(alloc);
+  // Use the segment-aware builder so dynamic sizes and symbol operands are
+  // split correctly, and carry over any requested alignment.
+  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
+      alloc, alloc.getMemref().getType(), alloc.getDynamicSizes(),
+      alloc.getSymbolOperands(), alloc.getAlignmentAttr());
+  rewriter.eraseOp(dealloc);
+  return alloca;
+}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3c563c0a36c3bb0..7bbbbba4134b184 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -378,8 +378,8 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
RewritePatternSet patterns(ctx);
if (!getRegion().empty()) {
for (Operation &op : getRegion().front()) {
- cast<transform::PatternDescriptorOpInterface>(&op).populatePatterns(
- patterns);
+ cast<transform::PatternDescriptorOpInterface>(&op)
+ .populatePatternsWithState(patterns, state);
}
}
diff --git a/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir b/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir
new file mode 100644
index 000000000000000..2e788236ca91152
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/alloc-to-alloca.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=all %s | FileCheck %s --check-prefixes=CHECK,ALL
+// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=small %s | FileCheck %s --check-prefixes=CHECK,SMALL
+
+func.func private @callee(memref<*xf32>)
+
+// CHECK-LABEL: @large_alloc
+func.func @large_alloc() {
+  // SMALL: memref.alloc()
+  // ALL: memref.alloca
+  %0 = memref.alloc() : memref<100x100xf32>
+  %1 = memref.cast %0 : memref<100x100xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // SMALL: memref.dealloc
+  // ALL-NOT: memref.dealloc
+  memref.dealloc %0 : memref<100x100xf32>
+  return
+}
+
+// CHECK-LABEL: @small_alloc
+func.func @small_alloc() {
+  // CHECK: memref.alloca
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // CHECK-NOT: memref.dealloc
+  memref.dealloc %0 : memref<2x2xf32>
+  return
+}
+
+// CHECK-LABEL: @no_dealloc
+func.func @no_dealloc() {
+  // CHECK: memref.alloc()
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  return
+}
+
+// CHECK-LABEL: @mismatching_scope
+func.func @mismatching_scope() {
+  // CHECK: memref.alloc()
+  %0 = memref.alloc() : memref<2x2xf32>
+  %1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  scf.execute_region {
+    memref.dealloc %0 : memref<2x2xf32>
+    scf.yield
+  }
+  return
+}
+
+// Dynamically shaped allocations are never rewritten, even without a size
+// limit, since the pattern only handles memrefs of static shape.
+// CHECK-LABEL: @dynamic_alloc
+func.func @dynamic_alloc(%arg0: index) {
+  // CHECK: memref.alloc(
+  %0 = memref.alloc(%arg0) : memref<?xf32>
+  %1 = memref.cast %0 : memref<?xf32> to memref<*xf32>
+  call @callee(%1) : (memref<*xf32>) -> ()
+  // CHECK: memref.dealloc
+  memref.dealloc %0 : memref<?xf32>
+  return
+}
+
+transform.sequence failures(propagate) attributes {transform.target_tag = "all"} {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.memref.alloc_to_alloca
+  } : !transform.any_op
+  transform.yield
+}
+
+transform.sequence failures(propagate) attributes {transform.target_tag = "small"} {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %0 {
+    transform.apply_patterns.memref.alloc_to_alloca size_limit(32)
+  } : !transform.any_op
+  transform.yield
+}
More information about the Mlir-commits
mailing list