[Mlir-commits] [mlir] 7a1579a - [mlir][bufferization] Move one-shot bufferization to Bufferization dialect
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 20 01:21:45 PST 2022
Author: Matthias Springer
Date: 2022-01-20T18:21:20+09:00
New Revision: 7a1579ac67fd8daca4b78a9883e574b41a8f8b69
URL: https://github.com/llvm/llvm-project/commit/7a1579ac67fd8daca4b78a9883e574b41a8f8b69
DIFF: https://github.com/llvm/llvm-project/commit/7a1579ac67fd8daca4b78a9883e574b41a8f8b69.diff
LOG: [mlir][bufferization] Move one-shot bufferization to Bufferization dialect
This commit is the first step towards unifying core bufferization and One-Shot Bufferize.
This commit does not move over the implementations of BufferizableOpInterface yet. This will be done in separate commits. This change also does not move the unit tests yet. The tests will be moved together with the op interface implementations and split into separate files.
Differential Revision: https://reviews.llvm.org/D117641
Added:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
similarity index 95%
rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 5abec82efe71..f679a22fa7a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -1,4 +1,4 @@
-//===- BufferizableOpInterface.h - Comprehensive Bufferize ------*- C++ -*-===//
+//===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
#include <utility>
@@ -25,13 +25,11 @@ class BlockAndValueMapping;
class DominanceInfo;
class FuncOp;
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
// TODO: from some HW description.
static constexpr int64_t kBufferAlignments = 128;
-class BufferizationAliasInfo;
class BufferizableOpInterface;
struct BufferizationOptions;
class BufferizationState;
@@ -241,7 +239,8 @@ class BufferizationState {
}
/// Return dialect-specific bufferization state or create one if none exists.
- template <typename StateT> StateT &getOrCreateDialectState(StringRef name) {
+ template <typename StateT>
+ StateT &getOrCreateDialectState(StringRef name) {
// Create state if it does not exist yet.
if (!dialectState.count(name))
dialectState[name] = std::make_unique<StateT>();
@@ -321,15 +320,13 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options);
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
/// AllocationHoistingBarrierOnly is an external implementation of
/// BufferizableOpInterface for ops that are (not yet) bufferizable, but are
@@ -378,8 +375,7 @@ struct AllocationHoistingBarrierOnly
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
similarity index 98%
rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 6569b0df6812..e070961f9eb8 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -1,4 +1,4 @@
-//===-- BufferizableOpInterface.td - Compreh. Bufferize ----*- tablegen -*-===//
+//===-- BufferizableOpInterface.td - Bufferizable Ops ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,7 +16,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
An op interface for Comprehensive Bufferization. Ops that implement this
interface can be bufferized using Comprehensive Bufferization.
}];
- let cppNamespace = "::mlir::linalg::comprehensive_bufferize";
+ let cppNamespace = "::mlir::bufferization";
let methods = [
InterfaceMethod<
/*desc=*/[{
@@ -311,12 +311,12 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
// TODO: The following two attributes should belong to the tensor dialect.
// The corresponding verifier should also be in the tensor dialect.
/// Attribute name used to mark region arguments that can be bufferized
- /// in-place during linalg comprehensive bufferization.
+ /// in-place during one-shot bufferization.
constexpr const static ::llvm::StringLiteral
kInplaceableAttrName = "linalg.inplaceable";
/// Attribute name used to mark the bufferization layout for region
- /// arguments during linalg comprehensive bufferization.
+ /// arguments during one-shot bufferization.
constexpr const static ::llvm::StringLiteral
kBufferLayoutAttrName = "linalg.buffer_layout";
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
similarity index 61%
rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
index 78de715ea66f..7b903b59f176 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h
@@ -6,22 +6,20 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_
namespace mlir {
class DialectRegistry;
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
namespace bufferization_ext {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace bufferization_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATIONINTERFACEIMPL_H
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 82017e7431f0..8ddfe5a384c0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -1,3 +1,4 @@
add_mlir_dialect(BufferizationOps bufferization)
add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
+add_mlir_interface(BufferizableOpInterface)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 0d3215555c62..c2254dfa3a16 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -52,6 +52,22 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
void populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
+class BufferizationState;
+
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Whether buffer copies are needed or not is queried from `state`.
+///
+/// Note: If `allowUnknownOps` is set to false, bufferization fails when an
+/// unknown op (that does not implement `BufferizableOpInterface`) is found. No
+/// to_tensor/to_memref ops are inserted in that case.
+///
+/// Note: The layout map chosen to bufferize is the most dynamic canonical
+/// strided layout of the proper rank. This ensures compatibility with expected
+/// layouts after transformations. Combinations of memref.cast +
+/// canonicalization are responsible for clean ups.
+// TODO: Extract `options` from `state` and pass as separate argument.
+LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
similarity index 88%
rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 468f1d638220..93b8be9c7c7a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -1,4 +1,4 @@
-//===- ComprehensiveBufferize.h - Linalg bufferization pass -----*- C++ -*-===//
+//===- OneShotAnalysis.h - One-Shot (Single Pass) Analysis ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,22 +6,19 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
namespace mlir {
-
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
class AnalysisBufferizationState;
class BufferizationAliasInfo;
struct AnalysisBufferizationOptions;
-class BufferizationState;
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
@@ -168,7 +165,7 @@ class AnalysisBufferizationState : public BufferizationState {
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
- /// functions and `runComprehensiveBufferize` may access this object.
+ /// functions and `runOneShotBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
};
@@ -176,16 +173,12 @@ class AnalysisBufferizationState : public BufferizationState {
/// `state`.
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
-/// Bufferize `op` and its nested ops. Bufferization decisions are stored in
-/// `state`.
-LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
-
-/// Run Comprehensive Bufferize on the given op: Analysis + Bufferization
-LogicalResult runComprehensiveBufferize(
- Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options);
+/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
+LogicalResult
+runOneShotBufferize(Operation *op,
+ std::unique_ptr<AnalysisBufferizationOptions> options);
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index c2a58e98061b..2022f1459c10 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -1,5 +1,2 @@
-set(LLVM_TARGET_DEFINITIONS BufferizableOpInterface.td)
-mlir_tablegen(BufferizableOpInterface.h.inc -gen-op-interface-decls)
-mlir_tablegen(BufferizableOpInterface.cpp.inc -gen-op-interface-defs)
-add_public_tablegen_target(MLIRBufferizableOpInterfaceIncGen)
-add_dependencies(mlir-headers MLIRBufferizableOpInterfaceIncGen)
+# no targets defined here
+
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 8614c9d50acf..05f2257b972d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -9,20 +9,16 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
namespace mlir {
-
class DialectRegistry;
namespace linalg {
namespace comprehensive_bufferize {
-
-class BufferizationAliasInfo;
-
namespace linalg_ext {
-struct InitTensorEliminationStep : public PostAnalysisStep {
+struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
/// A function that matches anchor OpOperands for InitTensorOp elimination.
using AnchorMatchFn = std::function<bool(OpOperand &)>;
@@ -39,11 +35,11 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// InitTensorOp.
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
/// This analysis can be skipped with `skipAnalysis`.
- LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- AnchorMatchFn anchorMatchFunc,
- RewriteFn rewriteFunc,
- SmallVector<Operation *> &newOps);
+ LogicalResult
+ eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
+ AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+ SmallVector<Operation *> &newOps);
};
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
@@ -51,8 +47,8 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// (and some other conditions are met).
struct InsertSliceAnchoredInitTensorEliminationStep
: public InitTensorEliminationStep {
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
+ LogicalResult run(Operation *op, bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index 6b4039f5283f..194465e29c4c 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -17,16 +17,19 @@ class DialectRegistry;
struct LogicalResult;
class ModuleOp;
+namespace bufferization {
+struct AnalysisBufferizationOptions;
+} // namespace bufferization
+
namespace linalg {
namespace comprehensive_bufferize {
-struct AnalysisBufferizationOptions;
-
/// Run Module Bufferization on the given module. Performs a simple function
/// call analysis to determine which function arguments are inplaceable. Then
-/// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization.
+/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
LogicalResult runComprehensiveBufferize(
- ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options);
+ ModuleOp moduleOp,
+ std::unique_ptr<bufferization::AnalysisBufferizationOptions> options);
namespace std_ext {
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index a2ba910aeac9..afea3fae490d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -9,7 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
namespace mlir {
@@ -23,9 +23,9 @@ namespace scf_ext {
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
/// currently assumed to alias with the i-th iter_arg (in the absence of
/// conflicts).
-struct AssertScfForAliasingProperties : public PostAnalysisStep {
- LogicalResult run(Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
+struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
+ LogicalResult run(Operation *op, bufferization::BufferizationState &state,
+ bufferization::BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) override;
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
similarity index 84%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
rename to mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 57bae783f9d7..fb081d3d6c3c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -1,4 +1,4 @@
-//===- BufferizableOpInterface.cpp - Comprehensive Bufferize --------------===//
+//===- BufferizableOpInterface.cpp - Bufferizable Ops --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
@@ -19,13 +18,11 @@
#include "llvm/Support/Debug.h"
namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
#define DEBUG_TYPE "bufferizable-op-interface"
@@ -33,7 +30,7 @@ namespace comprehensive_bufferize {
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
using namespace mlir;
-using namespace linalg::comprehensive_bufferize;
+using namespace bufferization;
//===----------------------------------------------------------------------===//
// BufferizationOptions
@@ -42,15 +39,15 @@ using namespace linalg::comprehensive_bufferize;
// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() {}
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
- BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
+BufferizableOpInterface
+BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
if (isOpAllowed(op))
return dyn_cast<BufferizableOpInterface>(op);
return nullptr;
}
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
- BufferizationOptions::dynCastBufferizableOp(Value value) const {
+BufferizableOpInterface
+BufferizationOptions::dynCastBufferizableOp(Value value) const {
if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
if (isOpAllowed(bufferizableOp.getOperation()))
return bufferizableOp;
@@ -72,8 +69,7 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
-mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
- OpResult result) const {
+BufferizationState::getAliasingOpOperand(OpResult result) const {
if (Operation *op = result.getDefiningOp())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.getAliasingOpOperand(result, *this);
@@ -82,9 +78,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult
-mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
- OpOperand &opOperand) const {
+OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.getAliasingOpResult(opOperand, *this);
@@ -93,8 +87,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToMemoryRead(OpOperand &opOperand) const {
+bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
@@ -106,8 +99,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToMemoryWrite(OpOperand &opOperand) const {
+bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
@@ -119,8 +111,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToAliasOnly(OpOperand &opOperand) const {
+bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
@@ -133,8 +124,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
- Value value) const {
+bool BufferizationState::isValueRead(Value value) const {
assert(value.getType().isa<TensorType>() && "expected TensorType");
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
@@ -157,9 +147,8 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
-llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
- BufferizationState::findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition) const {
+llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
+ Value value, llvm::function_ref<bool(Value)> condition) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -185,8 +174,8 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
}
// Find the Values of the last preceding write of a given Value.
-llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
- BufferizationState::findLastPrecedingWrite(Value value) const {
+llvm::SetVector<Value>
+BufferizationState::findLastPrecedingWrite(Value value) const {
return findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
@@ -198,8 +187,7 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
});
}
-mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
- const BufferizationOptions &options)
+BufferizationState::BufferizationState(const BufferizationOptions &options)
: options(options) {}
// bufferization.to_memref is not allowed to change the rank.
@@ -237,8 +225,7 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) {
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-FailureOr<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
+FailureOr<Value> BufferizationState::getBuffer(
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
Optional<Operation *> customCopyInsertionPoint) const {
OpBuilder::InsertionGuard guard(rewriter);
@@ -294,8 +281,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
return resultBuffer;
}
-void mlir::linalg::comprehensive_bufferize::replaceOpWithBufferizedValues(
- RewriterBase &rewriter, Operation *op, ValueRange values) {
+void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
+ Operation *op,
+ ValueRange values) {
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
@@ -409,9 +397,10 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
- OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref,
- const BufferizationOptions &options) {
+FailureOr<Value>
+bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+ bool deallocMemref,
+ const BufferizationOptions &options) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -445,9 +434,10 @@ FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
}
/// Create a memref allocation.
-FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
- OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape,
- const BufferizationOptions &options) {
+FailureOr<Value>
+bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ArrayRef<Value> dynShape,
+ const BufferizationOptions &options) {
if (options.allocationFn)
return (*options.allocationFn)(b, loc, type, dynShape);
@@ -458,9 +448,9 @@ FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
}
/// Create a memref deallocation.
-LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc(
- OpBuilder &b, Location loc, Value allocatedBuffer,
- const BufferizationOptions &options) {
+LogicalResult
+bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+ const BufferizationOptions &options) {
if (options.deallocationFn)
return (*options.deallocationFn)(b, loc, allocatedBuffer);
@@ -470,9 +460,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc(
}
/// Create a memory copy between two memref buffers.
-LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy(
- OpBuilder &b, Location loc, Value from, Value to,
- const BufferizationOptions &options) {
+LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
+ Value from, Value to,
+ const BufferizationOptions &options) {
if (options.memCpyFn)
return (*options.memCpyFn)(b, loc, from, to);
@@ -484,27 +474,28 @@ LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy(
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
-bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
+bool bufferization::isFunctionArgument(Value value) {
auto bbArg = value.dyn_cast<BlockArgument>();
if (!bbArg)
return false;
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
}
-MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
- ShapedType shapedType, MemRefLayoutAttrInterface layout,
- Attribute memorySpace) {
+MemRefType
+bufferization::getContiguousMemRefType(ShapedType shapedType,
+ MemRefLayoutAttrInterface layout,
+ Attribute memorySpace) {
return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
layout, memorySpace);
}
-UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
- Type elementType, Attribute memorySpace) {
+UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType,
+ Attribute memorySpace) {
return UnrankedMemRefType::get(elementType, memorySpace);
}
-MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
- RankedTensorType tensorType, unsigned addressSpace) {
+MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType,
+ unsigned addressSpace) {
// TODO: address space decisions to connect with the actual alloc.
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
similarity index 84%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
rename to mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
index fd3632fb56d0..835a153eb854 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp
@@ -6,24 +6,21 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
-using namespace linalg;
-using namespace comprehensive_bufferize;
+using namespace mlir::bufferization;
namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
+namespace bufferization {
namespace bufferization_ext {
-// TODO: These ops should implement BufferizableOpInterface directly when moved
-// to the Bufferization dialect.
+// TODO: These ops should implement BufferizableOpInterface.
/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
/// to x. Other to_memref ops are ignored during bufferization.
@@ -57,7 +54,6 @@ struct ToMemrefOpInterface
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
// ToMemrefOps always bufferize inplace.
- // TODO: Remove ToMemrefOps from the analysis.
return true;
}
@@ -121,14 +117,11 @@ struct ToTensorOpInterface
};
} // namespace bufferization_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace bufferization
} // namespace mlir
-void mlir::linalg::comprehensive_bufferize::bufferization_ext::
- registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
- registry.addOpInterface<bufferization::ToMemrefOp,
- bufferization_ext::ToMemrefOpInterface>();
- registry.addOpInterface<bufferization::ToTensorOp,
- bufferization_ext::ToTensorOpInterface>();
+void bufferization_ext::registerBufferizableOpInterfaceExternalModels(
+ DialectRegistry ®istry) {
+ registry.addOpInterface<ToMemrefOp, bufferization_ext::ToMemrefOpInterface>();
+ registry.addOpInterface<ToTensorOp, bufferization_ext::ToTensorOpInterface>();
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index dd54c1abdbc4..cdb6656f0f0a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRBufferization
+ PARTIAL_SOURCES_INTENDED
AllocationOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
@@ -16,3 +17,17 @@ add_mlir_dialect_library(MLIRBufferization
MLIRTensor
MLIRMemRef
)
+
+add_mlir_dialect_library(MLIRBufferizableOpInterface
+ PARTIAL_SOURCES_INTENDED
+ BufferizableOpInterface.cpp
+ BufferizationInterfaceImpl.cpp
+
+ DEPENDS
+ MLIRBufferizableOpInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRBufferization
+ MLIRMemRef
+)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index cebff6a7a860..0bacc21d4688 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -8,10 +8,12 @@
#include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::bufferization;
@@ -129,3 +131,80 @@ std::unique_ptr<OperationPass<FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
return std::make_unique<FinalizingBufferizePass>();
}
+
+static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+
+/// Return true if the given op has a tensor result or a tensor operand.
+static bool hasTensorSemantics(Operation *op) {
+ bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
+ bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
+ return hasTensorResult || hasTensorOperand;
+}
+
+/// Rewrite pattern that bufferizes bufferizable ops.
+struct BufferizationPattern
+ : public OpInterfaceRewritePattern<BufferizableOpInterface> {
+ BufferizationPattern(MLIRContext *context, const BufferizationState &state,
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
+ state(state) {}
+
+ LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
+ PatternRewriter &rewriter) const override {
+ // No tensors => no buffers.
+ if (!hasTensorSemantics(bufferizableOp.getOperation()))
+ return failure();
+ if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
+ return failure();
+ return bufferizableOp.bufferize(rewriter, state);
+ }
+
+private:
+ const BufferizationState &state;
+};
+
+/// Check the result of bufferization. Return an error if an op was not
+/// bufferized, unless partial bufferization is allowed.
+static LogicalResult
+checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
+ if (!options.allowUnknownOps) {
+ // Check if all ops were bufferized.
+ LogicalResult status = success();
+ op->walk([&](Operation *op) {
+ if (!hasTensorSemantics(op))
+ return WalkResult::advance();
+
+ // Bufferization dialect ops will canonicalize away if all other ops are
+ // bufferized.
+ if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
+ return WalkResult::advance();
+
+ // Ops that are not in the allow list can be ignored.
+ if (!options.isOpAllowed(op))
+ return WalkResult::advance();
+
+ // Ops without any uses and no side effects will fold away.
+ if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
+ return WalkResult::advance();
+
+ status = op->emitError("op was not bufferized");
+ return WalkResult::interrupt();
+ });
+
+ if (failed(status))
+ return status;
+ }
+
+ return success();
+}
+
+LogicalResult bufferization::bufferizeOp(Operation *op,
+ const BufferizationState &state) {
+ // Bufferize the op and its nested ops.
+ OwningRewritePatternList patterns(op->getContext());
+ patterns.add<BufferizationPattern>(op->getContext(), state);
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ return failure();
+
+ return checkBufferizationResult(op, state.getOptions());
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index cab164f1ab05..b3f4fb38d003 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRBufferizationTransforms
Bufferize.cpp
BufferDeallocation.cpp
+ OneShotAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
@@ -9,7 +10,13 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
MLIRBufferizationPassIncGen
LINK_LIBS PUBLIC
+ MLIRBufferizableOpInterface
MLIRBufferization
+ MLIRControlFlowInterfaces
+ MLIRInferTypeOpInterface
+ MLIRIR
+ MLIRMemRef
MLIRPass
+ MLIRStandard
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
similarity index 88%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index ae9532e25dc3..c21f7f9704b5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1,4 +1,4 @@
-//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
+//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
//
-// Comprehensive Bufferize bufferizes function bodies. Function boundaries
-// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
-// ModuleBufferization.cpp is an extension of Comprehensive Bufferize for simple
+// One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp
+// bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
+// ModuleBufferization.cpp is an extension of One-Shot Analysis for simple
// call graphs.
//
-// Comprehensive Bufferize consists of two phases.
+// One-Shot Bufferize consists of two phases.
//
// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
// inserting buffer copies. The analysis queries op bufferization semantics
@@ -20,49 +20,43 @@
// function does not generate buffer copies for OpResults that were decided
// to bufferize inplace during the analysis phase.
//
+// This file contains only the analysis. The actual bufferization is implemented
+// via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a
+// helper function `runOneShotBufferize` that analyzes an op (and its nested
+// ops) and then bufferizes it.
+//
// Inplace bufferization decisions are passed from the analysis to the
// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
// They can be printed for debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
-// treated conservatively. E.g., the analysis has to assume that their
+// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
-// Note: If `allowUnknownOps` is set to false, bufferization fails when an
-// unknown op (that does not implement `BufferizableOpInterface`) is found. No
-// to_tensor/to_memref ops are inserted.
-//
-// This pass caters to high-performance codegen where buffer reuse is deemed
-// critical: the pass should fail if the bufferized form of the function needs
-// to return any buffer, unless `allowReturnMemref` is enabled.
-//
-// Lastly, note that layout map chosen to bufferize is the most dynamic
-// canonical strided layout of the proper rank. This ensures compatibility with
-// expected layouts after transformations. Combinations of memref.cast +
-// canonicalization are responsible for clean ups.
+// This analysis caters to high-performance codegen where buffer reuse is deemed
+// critical: the analysis should fail if the bufferized form of the function
+// needs to return a buffer, unless `allowReturnMemref` is enabled.
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include <random>
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
-using namespace linalg;
-using namespace tensor;
-using namespace comprehensive_bufferize;
+using namespace mlir::bufferization;
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -751,65 +745,8 @@ struct AssertDestinationPassingStyle : public PostAnalysisStep {
}
};
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
- : public OpInterfaceRewritePattern<BufferizableOpInterface> {
- BufferizationPattern(MLIRContext *context, const BufferizationState &state,
- PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
- state(state) {}
-
- LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
- PatternRewriter &rewriter) const override {
- // No tensors => no buffers.
- if (!hasTensorSemantics(bufferizableOp.getOperation()))
- return failure();
- if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
- return failure();
- return bufferizableOp.bufferize(rewriter, state);
- }
-
-private:
- const BufferizationState &state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
- if (!options.allowUnknownOps) {
- // Check if all ops were bufferized.
- LogicalResult status = success();
- op->walk([&](Operation *op) {
- if (!hasTensorSemantics(op))
- return WalkResult::advance();
-
- // Bufferization dialect ops will canonicalize away if all other ops are
- // bufferized.
- if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
- return WalkResult::advance();
-
- // Ops that are not in the allow list can be ignored.
- if (!options.isOpAllowed(op))
- return WalkResult::advance();
-
- // Ops without any uses and no side effects will fold away.
- if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
- return WalkResult::advance();
-
- status = op->emitError("op was not bufferized");
- return WalkResult::interrupt();
- });
-
- if (failed(status))
- return status;
- }
-
- return success();
-}
-
-LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
- Operation *op, AnalysisBufferizationState &state) {
+LogicalResult bufferization::analyzeOp(Operation *op,
+ AnalysisBufferizationState &state) {
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
const auto &options =
@@ -849,18 +786,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
return success();
}
-LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
- Operation *op, const BufferizationState &state) {
- // Bufferize the op and its nested ops.
- OwningRewritePatternList patterns(op->getContext());
- patterns.add<BufferizationPattern>(op->getContext(), state);
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
- return failure();
-
- return checkBufferizationResult(op, state.getOptions());
-}
-
-LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
+LogicalResult bufferization::runOneShotBufferize(
Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
AnalysisBufferizationState state(op, *options);
if (failed(analyzeOp(op, state)))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
index bed69f19582f..c1fad16126e4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
@@ -9,7 +9,9 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+
+using namespace mlir::bufferization;
void mlir::linalg::comprehensive_bufferize::affine_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 3c0926e3fae6..256916c5e7b3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -9,12 +9,14 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/BufferUtils.h"
+using namespace mlir::bufferization;
+
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index a912d2378dc2..f3601f3c3935 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -1,9 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp
ArithInterfaceImpl.cpp
- BufferizableOpInterface.cpp
- BufferizationInterfaceImpl.cpp
- ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
SCFInterfaceImpl.cpp
@@ -12,18 +9,6 @@ set(LLVM_OPTIONAL_SOURCES
VectorInterfaceImpl.cpp
)
-add_mlir_dialect_library(MLIRBufferizableOpInterface
- BufferizableOpInterface.cpp
-
- DEPENDS
- MLIRBufferizableOpInterfaceIncGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRBufferization
- MLIRMemRef
-)
-
add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
AffineInterfaceImpl.cpp
@@ -48,7 +33,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
- MLIRComprehensiveBufferize
+ MLIRBufferizationTransforms
MLIRIR
MLIRLinalg
MLIRTensor
@@ -59,7 +44,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
- MLIRComprehensiveBufferize
+ MLIRBufferizationTransforms
MLIRIR
MLIRSCF
)
@@ -91,18 +76,14 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
MLIRVector
)
-add_mlir_dialect_library(MLIRComprehensiveBufferize
- BufferizationInterfaceImpl.cpp
- ComprehensiveBufferize.cpp
+add_mlir_dialect_library(MLIRModuleBufferization
ModuleBufferization.cpp
LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
- MLIRControlFlowInterfaces
- MLIRInferTypeOpInterface
+ MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRStandard
MLIRStandardOpsTransforms
- MLIRTransforms
)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 6cce30d165a3..b01100cc9e08 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
@@ -17,6 +17,7 @@
using namespace mlir;
using namespace linalg;
using namespace comprehensive_bufferize;
+using namespace mlir::bufferization;
namespace {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 3c13e0a748f2..f4a2a5d69215 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -71,9 +71,10 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
@@ -82,6 +83,7 @@ using namespace mlir;
using namespace linalg;
using namespace tensor;
using namespace comprehensive_bufferize;
+using namespace mlir::bufferization;
namespace {
/// The state of analysis of a FuncOp.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 97cba18592c9..87dd5b09773d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -7,13 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
+using namespace mlir::bufferization;
+
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
index 5e603b376576..7941c979b09e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
@@ -8,11 +8,13 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+using namespace mlir::bufferization;
+
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index dc0742b99022..f6748985dde1 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -7,13 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
+using namespace mlir::bufferization;
namespace mlir {
namespace linalg {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 4dd7fdcebc03..24205478c4c0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -7,11 +7,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+using namespace mlir::bufferization;
+
namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index 98fdccbdeb7d..ccab0bbd6cb2 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -65,12 +65,12 @@ constexpr const ::llvm::StringLiteral
/// Attribute name used to mark the bufferization layout for region
/// arguments during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
- comprehensive_bufferize::BufferizableOpInterface::kBufferLayoutAttrName;
+ bufferization::BufferizableOpInterface::kBufferLayoutAttrName;
/// Attribute name used to mark region arguments that can be bufferized
/// in-place during linalg comprehensive bufferization.
constexpr const ::llvm::StringLiteral
- comprehensive_bufferize::BufferizableOpInterface::kInplaceableAttrName;
+ bufferization::BufferizableOpInterface::kInplaceableAttrName;
/// Trait to check if T provides a `regionBuilder` method.
template <typename T, typename... Args>
@@ -125,7 +125,7 @@ void mlir::linalg::LinalgDialect::initialize() {
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
- using comprehensive_bufferize::BufferizableOpInterface;
+ using bufferization::BufferizableOpInterface;
if (attr.getName() == BufferizableOpInterface::kInplaceableAttrName) {
if (!attr.getValue().isa<BoolAttr>()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 118a9436609b..bea51784b14c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRArithmetic
MLIRBufferizableOpInterface
MLIRComplex
- MLIRComprehensiveBufferize
MLIRInferTypeOpInterface
MLIRIR
MLIRMemRef
@@ -46,6 +45,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRLinalgAnalysis
MLIRLinalgBufferizableOpInterfaceImpl
MLIRLinalgUtils
+ MLIRModuleBufferization
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRSCFTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f8233c5bfe11..5ea77641b7f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -8,12 +8,12 @@
#include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
@@ -27,6 +27,7 @@
#include "mlir/Transforms/Passes.h"
using namespace mlir;
+using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::linalg::comprehensive_bufferize;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8284dd97b3d2..8f7fb89afba4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -13,8 +13,8 @@
#include "CodegenUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -284,8 +284,7 @@ static bool isInPlace(Value val) {
if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
arg.getArgNumber(),
- linalg::comprehensive_bufferize::BufferizableOpInterface::
- kInplaceableAttrName))
+ bufferization::BufferizableOpInterface::kInplaceableAttrName))
return attr.getValue();
return false;
}
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 7c9ad470eacf..10657c8c514f 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -17,7 +17,7 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic
MLIRBufferizableOpInterface
- MLIRComprehensiveBufferize
+ MLIRBufferizationTransforms
MLIRGPUTransforms
MLIRLinalg
MLIRLinalgBufferizableOpInterfaceImpl
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 87677ef2383d..1ba0b891692b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -12,14 +12,13 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
@@ -34,6 +33,7 @@
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::comprehensive_bufferize;
+using namespace mlir::bufferization;
namespace {
/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
@@ -118,8 +118,8 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
options->dialectFilter->insert(dialectNamespace);
}
- Operation *op = getOperation().getOperation();
- if (failed(runComprehensiveBufferize(op, std::move(options))))
+ Operation *op = getOperation();
+ if (failed(runOneShotBufferize(op, std::move(options))))
return;
if (testAnalysisOnly)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9b67e9a819d1..cd4d2e195d03 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6506,7 +6506,7 @@ gentbl_cc_library(
td_library(
name = "BufferizableOpInterfaceTdFiles",
srcs = [
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td",
+ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td",
],
includes = ["include"],
deps = [
@@ -6520,15 +6520,15 @@ gentbl_cc_library(
tbl_outs = [
(
["-gen-op-interface-decls"],
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc",
+ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc",
+ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td",
+ td_file = "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td",
deps = [
":BufferizableOpInterfaceTdFiles",
],
@@ -6537,10 +6537,12 @@ gentbl_cc_library(
cc_library(
name = "BufferizableOpInterface",
srcs = [
- "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp",
+ "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
+ "lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp",
],
hdrs = [
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h",
+ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
+ "include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h",
],
includes = ["include"],
deps = [
@@ -6601,7 +6603,7 @@ cc_library(
deps = [
":BufferizableOpInterface",
":BufferizationDialect",
- ":ComprehensiveBufferize",
+ ":BufferizationTransforms",
":IR",
":LinalgOps",
":LinalgStructuredOpsIncGen",
@@ -6621,7 +6623,7 @@ cc_library(
deps = [
":BufferizableOpInterface",
":BufferizationDialect",
- ":ComprehensiveBufferize",
+ ":BufferizationTransforms",
":IR",
":SCFDialect",
":Support",
@@ -6891,7 +6893,6 @@ cc_library(
":BufferizationDialect",
":BufferizationTransforms",
":ComplexDialect",
- ":ComprehensiveBufferize",
":DialectUtils",
":IR",
":InferTypeOpInterface",
@@ -6901,6 +6902,7 @@ cc_library(
":LinalgStructuredOpsIncGen",
":MathDialect",
":MemRefDialect",
+ ":ModuleBufferization",
":Pass",
":SCFBufferizableOpInterfaceImpl",
":SCFDialect",
@@ -6921,30 +6923,23 @@ cc_library(
)
cc_library(
- name = "ComprehensiveBufferize",
+ name = "ModuleBufferization",
srcs = [
- "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp",
- "lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
"lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
],
hdrs = [
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h",
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
],
includes = ["include"],
deps = [
":BufferizableOpInterface",
":BufferizationDialect",
- ":ControlFlowInterfaces",
+ ":BufferizationTransforms",
":DialectUtils",
":IR",
- ":InferTypeOpInterface",
":MemRefDialect",
- ":Pass",
":StandardOps",
":Support",
- ":Transforms",
"//llvm:Support",
],
)
@@ -7957,12 +7952,10 @@ gentbl_cc_library(
cc_library(
name = "BufferizationDialect",
- srcs = glob(
- [
- "lib/Dialect/Bufferization/IR/Bufferization*.h",
- "lib/Dialect/Bufferization/IR/Bufferization*.cpp",
- ],
- ),
+ srcs = [
+ "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
+ "lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+ ],
hdrs = ["include/mlir/Dialect/Bufferization/IR/Bufferization.h"],
includes = ["include"],
deps = [
@@ -8011,11 +8004,17 @@ cc_library(
deps = [
":AllocationOpInterface",
":Analysis",
+ ":BufferizableOpInterface",
":BufferizationDialect",
":BufferizationPassIncGen",
+ ":ControlFlowInterfaces",
+ ":DialectUtils",
":IR",
+ ":InferTypeOpInterface",
":MemRefDialect",
":Pass",
+ ":StandardOps",
+ ":Support",
":Transforms",
"//llvm:Support",
],
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index bb5de09f9687..c494b994a4f3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -390,7 +390,7 @@ cc_library(
"//mlir:ArithmeticDialect",
"//mlir:BufferizableOpInterface",
"//mlir:BufferizationDialect",
- "//mlir:ComprehensiveBufferize",
+ "//mlir:BufferizationTransforms",
"//mlir:GPUDialect",
"//mlir:IR",
"//mlir:LinalgBufferizableOpInterfaceImpl",
More information about the Mlir-commits
mailing list