[Mlir-commits] [mlir] b661788 - [mlir] NFC - Expose GlobalCreator so it can be reused.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 16 05:29:36 PDT 2021


Author: Nicolas Vasilache
Date: 2021-03-16T12:29:04Z
New Revision: b661788b77e570dc82fe2f89a355713c144407f1

URL: https://github.com/llvm/llvm-project/commit/b661788b77e570dc82fe2f89a355713c144407f1
DIFF: https://github.com/llvm/llvm-project/commit/b661788b77e570dc82fe2f89a355713c144407f1.diff

LOG: [mlir] NFC - Expose GlobalCreator so it can be reused.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/BufferUtils.h
    mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
index 70da6a025343..33edffa372a3 100644
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -120,6 +120,24 @@ class BufferPlacementTransformationBase {
   Liveness liveness;
 };
 
+namespace memref {
+class GlobalOp;
+} // namespace memref
+
+/// Creates memref::GlobalOps for tensor-valued constants. Globals are created
+/// lazily at the start of `moduleOp` with pretty names; duplicates are avoided.
+class GlobalCreator {
+public:
+  explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
+  /// Returns the global holding `constantOp`'s value, creating it if needed.
+  memref::GlobalOp getGlobalFor(ConstantOp constantOp);
+
+private:
+  ModuleOp moduleOp;
+  // Keyed by constant value. The mapped type could be memref::GlobalOp, but we
+  // avoid introducing a new dependence on the memref dialect for this.
+  DenseMap<Attribute, Operation *> globals;
+};
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_BUFFERUTILS_H

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 18c3be94685b..55d34059e033 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -15,64 +15,47 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
-namespace {
-// This class creates global ops for all tensor-valued constants in the program.
-// It creates them with pretty names and makes sure that duplicate globals
-// aren't created.
-class GlobalCreator {
-public:
-  explicit GlobalCreator(ModuleOp module);
-  memref::GlobalOp getGlobalFor(Attribute attr) {
-    assert(globals.find(attr) != globals.end() && "unknown constant attr");
-    return globals[attr];
-  }
-
-private:
-  DenseMap<Attribute, memref::GlobalOp> globals;
-};
+memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
+  auto type = constantOp.getType().cast<RankedTensorType>();

-GlobalCreator::GlobalCreator(ModuleOp module) {
   BufferizeTypeConverter typeConverter;
+
+  // If we already have a global for this constant value, no need to do
+  // anything else.
+  auto it = globals.find(constantOp.getValue());
+  if (it != globals.end())
+    return cast<memref::GlobalOp>(it->second);
+
   // Create a builder without an insertion point. We will insert using the
   // symbol table to guarantee unique names.
-  OpBuilder globalBuilder(module.getContext());
-  SymbolTable symbolTable(module);
-  module.walk([&](ConstantOp op) {
-    // We only want tensor constants for now.
-    auto type = op.getType().dyn_cast<RankedTensorType>();
-    if (!type)
-      return;
-    // If we already have a global for this constant value, no need to do
-    // anything else.
-    auto it = globals.find(op.getValue());
-    if (it != globals.end())
-      return;
+  OpBuilder globalBuilder(moduleOp.getContext());
+  SymbolTable symbolTable(moduleOp);

-    // Create a pretty name.
-    SmallString<64> buf;
-    llvm::raw_svector_ostream os(buf);
-    interleave(type.getShape(), os, "x");
-    os << "x" << type.getElementType();
+  // Create a pretty name from the shape and element type, e.g. `2x3xf32`.
+  SmallString<64> buf;
+  llvm::raw_svector_ostream os(buf);
+  interleave(type.getShape(), os, "x");
+  os << "x" << type.getElementType();

-    auto global = globalBuilder.create<memref::GlobalOp>(
-        op.getLoc(), (Twine("__constant_") + os.str()).str(),
-        /*sym_visibility=*/globalBuilder.getStringAttr("private"),
-        /*type=*/typeConverter.convertType(type),
-        /*initial_value=*/op.getValue().cast<ElementsAttr>(),
-        /*constant=*/true);
-    symbolTable.insert(global);
-    // The symbol table inserts at the end of the module, but globals are a bit
-    // nicer if they are at the beginning.
-    global->moveBefore(&module.front());
-    globals[op.getValue()] = global;
-  });
+  auto global = globalBuilder.create<memref::GlobalOp>(
+      constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
+      /*sym_visibility=*/globalBuilder.getStringAttr("private"),
+      /*type=*/typeConverter.convertType(type),
+      /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
+      /*constant=*/true);
+  symbolTable.insert(global);
+  // The symbol table inserts at the end of the module, but globals are a bit
+  // nicer if they are at the beginning.
+  global->moveBefore(&moduleOp.front());
+  globals[constantOp.getValue()] = global;
+  return global;
 }
-} // namespace
 
 namespace {
 class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
@@ -89,7 +72,7 @@ class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
     if (!type)
       return failure();
 
-    auto globalMemref = globals.getGlobalFor(op.value());
+    auto globalMemref = globals.getGlobalFor(op);
     rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, globalMemref.type(),
                                                      globalMemref.getName());
     return success();


        


More information about the Mlir-commits mailing list