[Mlir-commits] [mlir] Revert "[mlir] Remove dialect specific bufferization passes" (PR #93528)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 03:21:59 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#93488
Buildbot failure: https://lab.llvm.org/buildbot/#/builders/220/builds/39911
---
Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93528.diff
32 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+16)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+5)
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.h (+4)
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+10)
- (modified) mlir/include/mlir/Dialect/Shape/Transforms/Passes.h (+7)
- (modified) mlir/include/mlir/Dialect/Shape/Transforms/Passes.td (+7)
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td (+5)
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+5)
- (added) mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp (+67)
- (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+23)
- (added) mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp (+52)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp (+49)
- (modified) mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp (+58)
- (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp (+55)
- (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
- (modified) mlir/test/Dialect/Arith/bufferize.mlir (+6-2)
- (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+29-1)
- (modified) mlir/test/Dialect/Shape/bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+2-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+2-1)
- (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+2-1)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+1-1)
- (modified) mlir/test/Dialect/Vector/bufferize-invalid.mlir (+2-1)
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 9dc262cc72ed0..cbc6147cb81e2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -24,6 +24,9 @@ namespace arith {
class WideIntEmulationConverter;
class NarrowTypeEmulationConverter;
+/// Create a pass to bufferize arith.constant ops.
+std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+
/// Adds patterns to emulate wide Arith and Function ops over integer
/// types into supported ones. This is done by splitting original power-of-two
/// i2N integer types into two iN halves.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 550c5c0cf4f60..4096e309199e9 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -11,6 +11,22 @@
include "mlir/Pass/PassBase.td"
+def ArithBufferizePass : Pass<"arith-bufferize", "ModuleOp"> {
+ let summary = "Bufferize Arith dialect ops.";
+ let description = [{
+ This pass bufferizes arith dialect ops.
+
+ This pass needs to be a module pass because it inserts memref.global
+ ops into the module, which cannot be done safely from a function pass due to
+ multi-threading. Most other bufferization passes can run in parallel at
+ function granularity.
+ }];
+ let options = [
+ Option<"alignment", "alignment", "unsigned", /*default=*/"0",
+ "Create global memrefs with a specified alignment">, // Copied into BufferizationOptions::bufferAlignment by the C++ pass.
+ ];
+}
+
def ArithExpandOpsPass : Pass<"arith-expand"> {
let summary = "Legalize Arith ops to be convertible to LLVM.";
let dependentDialects = ["vector::VectorDialect"];
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..459c252b70712 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -221,6 +221,9 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// insert_slice ops.
std::unique_ptr<Pass> createEmptyTensorEliminationPass();
+/// Create a pass that bufferizes ops from the bufferization dialect.
+std::unique_ptr<Pass> createBufferizationBufferizePass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 8f8826b9ad56b..75ce85c9128c9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -350,6 +350,11 @@ def FinalizingBufferize : Pass<"finalizing-bufferize", "func::FuncOp"> {
let constructor = "mlir::bufferization::createFinalizingBufferizePass()";
}
+def BufferizationBufferize : Pass<"bufferization-bufferize", "func::FuncOp"> {
+ let summary = "Bufferize the `bufferization` dialect"; // C++: BufferizationBufferizePass in Transforms/Bufferize.cpp.
+ let constructor = "mlir::bufferization::createBufferizationBufferizePass()";
+}
+
def DropEquivalentBufferResults : Pass<"drop-equivalent-buffer-results", "ModuleOp"> {
let summary = "Remove MemRef return values that are equivalent to a bbArg";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index f2955d55e59ec..d36d1e70f0b14 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -22,6 +22,10 @@ namespace func {
class FuncOp;
} // namespace func
+namespace bufferization {
+struct OneShotBufferizationOptions;
+} // namespace bufferization
+
#define GEN_PASS_DECL
#include "mlir/Dialect/Linalg/Passes.h.inc" // IWYU pragma: keep
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 0621a9f33ba1e..0a4ce8953136d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -89,6 +89,16 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
];
}
+def LinalgBufferizePass : Pass<"linalg-bufferize"> {
+ let summary = "Bufferize the linalg dialect";
+ let dependentDialects = [ // NOTE: the C++ LinalgBufferizePass overrides getDependentDialects; keep both lists in sync.
+ "affine::AffineDialect",
+ "bufferization::BufferizationDialect",
+ "linalg::LinalgDialect",
+ "memref::MemRefDialect",
+ ];
+}
+
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let summary = "Convert named ops into generic ops";
let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 28e17459ff962..cfb637f133f54 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -47,6 +47,13 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
+// Bufferizes shape dialect ops.
+//
+// Note that most shape dialect ops must be converted to std before
+// bufferization happens, as they are intended to be bufferized at the std
+// level.
+std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
+
/// Outline the shape computation part by adding shape.func and populate
/// conrresponding mapping infomation into ShapeMappingAnalysis.
std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 83834509b4a35..9dfda9ea33615 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -103,4 +103,11 @@ def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
let constructor = "mlir::createShapeToShapeLowering()";
}
+// TODO: Generalize this to allow any type conversions desired.
+def ShapeBufferize : Pass<"shape-bufferize", "func::FuncOp"> {
+ let summary = "Bufferize the shape dialect."; // C++: ShapeBufferizePass in Shape/Transforms/Bufferize.cpp.
+ let constructor = "mlir::createShapeBufferizePass()";
+ let dependentDialects = ["bufferization::BufferizationDialect",
+ "memref::MemRefDialect"]; // The C++ getDependentDialects override also inserts shape::ShapeDialect; keep in sync.
+}
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index 964c35b3f15b8..48f9066934a25 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -21,6 +21,9 @@ namespace tensor {
/// Creates an instance of the `tensor` subset folding pass.
std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
+/// Creates an instance of the `tensor` dialect bufferization pass.
+std::unique_ptr<Pass> createTensorBufferizePass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index be4c333836ec0..4cc3844f29120 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -27,4 +27,9 @@ def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
];
}
+def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
+ let summary = "Bufferize the `tensor` dialect"; // Anchored on func::FuncOp; factory declared in Tensor/Transforms/Passes.h.
+ let constructor = "mlir::tensor::createTensorBufferizePass()";
+}
+
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..911402551e14d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -17,6 +17,9 @@ namespace vector {
#define GEN_PASS_DECL
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+/// Creates an instance of the `vector` dialect bufferization pass.
+std::unique_ptr<Pass> createVectorBufferizePass();
+
/// Creates an instance of the `vector.mask` lowering pass.
std::unique_ptr<Pass> createLowerVectorMaskPass();
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..31a0b3b2f0c53 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -11,6 +11,11 @@
include "mlir/Pass/PassBase.td"
+def VectorBufferize : Pass<"vector-bufferize", "func::FuncOp"> {
+ let summary = "Bufferize Vector dialect ops"; // Anchored on func::FuncOp; factory declared in Vector/Transforms/Passes.h.
+ let constructor = "mlir::vector::createVectorBufferizePass()";
+}
+
def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
let summary = "Lower 'vector.mask' operations";
let constructor = "mlir::vector::createLowerVectorMaskPass()";
diff --git a/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..9a066756f429c
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/Bufferize.cpp
@@ -0,0 +1,67 @@
+//===- Bufferize.cpp - Bufferization for Arith ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+namespace mlir {
+namespace arith {
+#define GEN_PASS_DEF_ARITHBUFFERIZEPASS
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace arith
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Bufferizes arith ops (only arith.constant when `constantOpOnly` is set).
+struct ArithBufferizePass
+    : public arith::impl::ArithBufferizePassBase<ArithBufferizePass> {
+  using ArithBufferizePassBase::ArithBufferizePassBase;
+
+  ArithBufferizePass(uint64_t alignment = 0, bool constantOpOnly = false)
+      : constantOpOnly(constantOpOnly) {
+    this->alignment = alignment; // Narrows to the `unsigned` pass option.
+  }
+
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    if (constantOpOnly) {
+      options.opFilter.allowOperation<arith::ConstantOp>();
+    } else {
+      options.opFilter.allowDialect<arith::ArithDialect>();
+    }
+    options.bufferAlignment = alignment; // From the `alignment` option.
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    arith::ArithDialect>();
+    arith::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+
+private:
+  bool constantOpOnly = false; // Default: inherited ctors don't set it.
+};
+} // namespace
+
+/// Creates the arith bufferize pass restricted to `arith.constant` ops.
+std::unique_ptr<Pass>
+mlir::arith::createConstantBufferizePass(uint64_t alignment) {
+  return std::make_unique<ArithBufferizePass>(alignment,
+                                              /*constantOpOnly=*/true);
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6b8bde8dc2aaf..12659eaba1fa5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
+ Bufferize.cpp
BufferViewFlowOpInterfaceImpl.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 0fddd60eb8140..7ba347a1f15e4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -320,6 +320,29 @@ struct OneShotBufferizePass
};
} // namespace
+namespace {
+/// Partially bufferizes only the ops of the bufferization dialect.
+struct BufferizationBufferizePass
+    : public bufferization::impl::BufferizationBufferizeBase<
+          BufferizationBufferizePass> {
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.opFilter.allowDialect<BufferizationDialect>();
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
+  }
+};
+} // namespace
+
+/// Creates the `bufferization-bufferize` pass.
+std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
+  return std::make_unique<BufferizationBufferizePass>();
+}
+
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
return std::make_unique<OneShotBufferizePass>();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..8812ca14ba610
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -0,0 +1,52 @@
+//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGBUFFERIZEPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Converts Linalg ops with tensor operands or results to work on buffers.
+struct LinalgBufferizePass
+    : public impl::LinalgBufferizePassBase<LinalgBufferizePass> {
+  using impl::LinalgBufferizePassBase<
+      LinalgBufferizePass>::LinalgBufferizePassBase;
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.opFilter.allowDialect<linalg::LinalgDialect>();
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // This override hides the generated one; chain to it so the dialects
+    // listed in Passes.td (including affine) are still registered.
+    LinalgBufferizePassBase::getDependentDialects(registry);
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    tensor::TensorDialect, linalg::LinalgDialect>();
+    linalg::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..ed9f40089282a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
AllInterfaces.cpp
BubbleUpExtractSlice.cpp
BufferizableOpInterfaceImpl.cpp
+ Bufferize.cpp
ConstantFold.cpp
ConvertToDestinationStyle.cpp
ConvertConv2DToImg2Col.cpp
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..9dadbdbc91eca
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -0,0 +1,49 @@
+//====----- Bufferize.cpp - Bufferization of shape ops ---------*- C++-*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SHAPEBUFFERIZE
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace bufferization;
+
+namespace {
+/// Partially bufferizes only the ops of the shape dialect.
+struct ShapeBufferizePass
+    : public impl::ShapeBufferizeBase<ShapeBufferizePass> {
+  void runOnOperation() override {
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.opFilter.allowDialect<shape::ShapeDialect>();
+    if (failed(bufferizeOp(getOperation(), options)))
+      signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    shape::ShapeDialect>();
+    shape::registerBufferizableOpInterfaceExternalModels(registry);
+  }
+};
+} // namespace
+
+/// Creates the `shape-bufferize` pass, which runs on func::FuncOp.
+std::unique_ptr<OperationPass<func::FuncOp>> mlir::createShapeBufferizePass() {
+  return std::make_unique<ShapeBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index a51c6780c2866..7c9b0d2e5e3a8 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRShapeOpsTransforms
BufferizableOpInterfaceImpl.cpp
+ Bufferize.cpp
OutlineShapeComputation.cpp
RemoveShapeConstraints.cpp
ShapeToShapeLowering.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..d27c4576a8b7a
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -0,0 +1,58 @@
+//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of `tensor` dialect ops
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/93528
More information about the Mlir-commits
mailing list