[Mlir-commits] [mlir] e448c79 - [mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect
Matthias Springer
llvmlistbot at llvm.org
Sun Jan 30 05:12:34 PST 2022
Author: Matthias Springer
Date: 2022-01-30T22:12:14+09:00
New Revision: e448c793c66521ee48d0107c33b80a2ff1baaaaf
URL: https://github.com/llvm/llvm-project/commit/e448c793c66521ee48d0107c33b80a2ff1baaaaf
DIFF: https://github.com/llvm/llvm-project/commit/e448c793c66521ee48d0107c33b80a2ff1baaaaf.diff
LOG: [mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect
Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and deleted.
Differential Revision: https://reviews.llvm.org/D118559
Added:
mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
Modified:
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
deleted file mode 100644
index ae3b3db23e648..0000000000000
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace std_ext {
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace std_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000000..a85acbbb195a6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,18 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 52bbea000d1f3..d6b8d2028e0e2 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -23,10 +23,6 @@ class BufferizeTypeConverter;
class RewritePatternSet;
-void populateStdBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns);
-
/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass();
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 6bd83938346e4..3e08865c6f71a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
let summary = "Bufferize the std dialect";
let constructor = "mlir::createStdBufferizePass()";
- let dependentDialects = ["bufferization::BufferizationDialect",
- "memref::MemRefDialect", "scf::SCFDialect"];
}
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 10d0f72ebcfe8..5733f88c953ab 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor
)
-add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
- StdInterfaceImpl.cpp
-
- LINK_LIBS PUBLIC
- MLIRBufferization
- MLIRStandard
-)
-
add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
VectorInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5a42474d49da4..d1418c40f035c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCF
MLIRSCFTransforms
MLIRSCFUtils
- MLIRStdBufferizableOpInterfaceImpl
MLIRPass
MLIRStandard
MLIRStandardOpsTransforms
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 9d2e6a539a182..f809bf35dc6fe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -15,10 +15,10 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry);
- std_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
similarity index 83%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
rename to mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
index 7941c979b09e2..b89a5372a48b6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,19 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
+using namespace mlir;
using namespace mlir::bufferization;
namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace std_ext {
+namespace {
/// Bufferization of std.select. Just replace the operands.
struct SelectOpInterface
@@ -69,12 +68,10 @@ struct SelectOpInterface
}
};
-} // namespace std_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace
} // namespace mlir
-void mlir::linalg::comprehensive_bufferize::std_ext::
- registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
- registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
+void mlir::registerBufferizableOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addOpInterface<SelectOp, SelectOpInterface>();
}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 9851784d4951c..64f9d040a71ca 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -12,64 +12,34 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
-
-namespace {
-class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(SelectOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (!op.getCondition().getType().isa<IntegerType>())
- return rewriter.notifyMatchFailure(op, "requires scalar condition");
-
- rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.getCondition(),
- adaptor.getTrueValue(),
- adaptor.getFalseValue());
- return success();
- }
-};
-} // namespace
-
-void mlir::populateStdBufferizePatterns(
- bufferization::BufferizeTypeConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
-}
+using namespace mlir::bufferization;
namespace {
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnOperation() override {
- auto *context = &getContext();
- bufferization::BufferizeTypeConverter typeConverter;
- RewritePatternSet patterns(context);
- ConversionTarget target(*context);
-
- target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
- memref::MemRefDialect>();
+ std::unique_ptr<BufferizationOptions> options =
+ getPartialBufferizationOptions();
+ options->addToDialectFilter<StandardOpsDialect>();
- populateStdBufferizePatterns(typeConverter, patterns);
- // We only bufferize the case of tensor selected type and scalar condition,
- // as that boils down to a select over memref descriptors (don't need to
- // touch the data).
- target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
- return typeConverter.isLegal(op.getType()) ||
- !op.getCondition().getType().isa<IntegerType>();
- });
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(bufferizeOp(getOperation(), *options)))
signalPassFailure();
}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+ StandardOpsDialect, scf::SCFDialect>();
+ mlir::registerBufferizableOpInterfaceExternalModels(registry);
+ }
};
} // namespace
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index d5869ce207cf8..7db425fdc361d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
+ BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
DecomposeCallGraphTypes.cpp
FuncBufferize.cpp
@@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
LINK_LIBS PUBLIC
MLIRAffine
MLIRArithmeticTransforms
+ MLIRBufferization
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 7ad34948bcebd..4fa6070921806 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses
MLIRPass
MLIRSCF
MLIRSCFTransforms
- MLIRStdBufferizableOpInterfaceImpl
MLIRStandard
+ MLIRStandardOpsTransforms
MLIRTensor
MLIRTensorTransforms
MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 5cde7cf2ac094..b074043b7a3ab 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -18,12 +18,12 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
@@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
- std_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ mlir::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 185998a0bf947..e786e82d4f5e6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6658,24 +6658,6 @@ cc_library(
],
)
-cc_library(
- name = "StdBufferizableOpInterfaceImpl",
- srcs = [
- "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
- ],
- hdrs = [
- "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
- ],
- includes = ["include"],
- deps = [
- ":BufferizationDialect",
- ":IR",
- ":StandardOps",
- ":Support",
- "//llvm:Support",
- ],
-)
-
cc_library(
name = "VectorBufferizableOpInterfaceImpl",
srcs = [
@@ -6916,7 +6898,6 @@ cc_library(
":SCFUtils",
":StandardOps",
":StandardOpsTransforms",
- ":StdBufferizableOpInterfaceImpl",
":Support",
":TensorDialect",
":TensorTransforms",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index d23ca654fd09b..4292c258c051c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -403,7 +403,7 @@ cc_library(
"//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",
- "//mlir:StdBufferizableOpInterfaceImpl",
+ "//mlir:StandardOpsTransforms",
"//mlir:TensorDialect",
"//mlir:TensorTransforms",
"//mlir:TransformUtils",
More information about the Mlir-commits
mailing list