[Mlir-commits] [mlir] b661788 - [mlir] NFC - Expose GlobalCreator so it can be reused.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Mar 16 05:29:36 PDT 2021
Author: Nicolas Vasilache
Date: 2021-03-16T12:29:04Z
New Revision: b661788b77e570dc82fe2f89a355713c144407f1
URL: https://github.com/llvm/llvm-project/commit/b661788b77e570dc82fe2f89a355713c144407f1
DIFF: https://github.com/llvm/llvm-project/commit/b661788b77e570dc82fe2f89a355713c144407f1.diff
LOG: [mlir] NFC - Expose GlobalCreator so it can be reused.
Added:
Modified:
mlir/include/mlir/Transforms/BufferUtils.h
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
index 70da6a025343..33edffa372a3 100644
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -120,6 +120,24 @@ class BufferPlacementTransformationBase {
Liveness liveness;
};
+namespace memref {
+class GlobalOp;
+} // namespace memref
+
+// Support class to create global ops for tensor-valued constants in the
+// program. Globals are created lazily at the top of the `moduleOp` with pretty
+// names. Duplicates are avoided.
+class GlobalCreator {
+public:
+ explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
+ memref::GlobalOp getGlobalFor(ConstantOp constantOp);
+
+private:
+ ModuleOp moduleOp;
+ // This could use memref::GlobalOp key but we avoid introducing a new
+ // dependence to the memref dialect for this.
+ DenseMap<Attribute, Operation *> globals;
+};
} // end namespace mlir
#endif // MLIR_TRANSFORMS_BUFFERUTILS_H
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 18c3be94685b..55d34059e033 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -15,64 +15,47 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
-namespace {
-// This class creates global ops for all tensor-valued constants in the program.
-// It creates them with pretty names and makes sure that duplicate globals
-// aren't created.
-class GlobalCreator {
-public:
- explicit GlobalCreator(ModuleOp module);
- memref::GlobalOp getGlobalFor(Attribute attr) {
- assert(globals.find(attr) != globals.end() && "unknown constant attr");
- return globals[attr];
- }
-
-private:
- DenseMap<Attribute, memref::GlobalOp> globals;
-};
+memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
+ auto type = constantOp.getType().cast<RankedTensorType>();
-GlobalCreator::GlobalCreator(ModuleOp module) {
BufferizeTypeConverter typeConverter;
+
+ // If we already have a global for this constant value, no need to do
+ // anything else.
+ auto it = globals.find(constantOp.getValue());
+ if (it != globals.end())
+ return cast<memref::GlobalOp>(it->second);
+
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
- OpBuilder globalBuilder(module.getContext());
- SymbolTable symbolTable(module);
- module.walk([&](ConstantOp op) {
- // We only want tensor constants for now.
- auto type = op.getType().dyn_cast<RankedTensorType>();
- if (!type)
- return;
- // If we already have a global for this constant value, no need to do
- // anything else.
- auto it = globals.find(op.getValue());
- if (it != globals.end())
- return;
+ OpBuilder globalBuilder(moduleOp.getContext());
+ SymbolTable symbolTable(moduleOp);
- // Create a pretty name.
- SmallString<64> buf;
- llvm::raw_svector_ostream os(buf);
- interleave(type.getShape(), os, "x");
- os << "x" << type.getElementType();
+ // Create a pretty name.
+ SmallString<64> buf;
+ llvm::raw_svector_ostream os(buf);
+ interleave(type.getShape(), os, "x");
+ os << "x" << type.getElementType();
- auto global = globalBuilder.create<memref::GlobalOp>(
- op.getLoc(), (Twine("__constant_") + os.str()).str(),
- /*sym_visibility=*/globalBuilder.getStringAttr("private"),
- /*type=*/typeConverter.convertType(type),
- /*initial_value=*/op.getValue().cast<ElementsAttr>(),
- /*constant=*/true);
- symbolTable.insert(global);
- // The symbol table inserts at the end of the module, but globals are a bit
- // nicer if they are at the beginning.
- global->moveBefore(&module.front());
- globals[op.getValue()] = global;
- });
+ auto global = globalBuilder.create<memref::GlobalOp>(
+ constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
+ /*sym_visibility=*/globalBuilder.getStringAttr("private"),
+ /*type=*/typeConverter.convertType(type),
+ /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
+ /*constant=*/true);
+ symbolTable.insert(global);
+ // The symbol table inserts at the end of the module, but globals are a bit
+ // nicer if they are at the beginning.
+ global->moveBefore(&moduleOp.front());
+ globals[constantOp.getValue()] = global;
+ return global;
}
-} // namespace
namespace {
class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
@@ -89,7 +72,7 @@ class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
if (!type)
return failure();
- auto globalMemref = globals.getGlobalFor(op.value());
+ auto globalMemref = globals.getGlobalFor(op);
rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, globalMemref.type(),
globalMemref.getName());
return success();
More information about the Mlir-commits mailing list