[Mlir-commits] [mlir] e2a37bb - [mlir] Add alignment option to constant tensor bufferization pass

Eugene Zhulenev llvmlistbot at llvm.org
Fri Oct 8 03:17:26 PDT 2021


Author: Eugene Zhulenev
Date: 2021-10-08T03:17:20-07:00
New Revision: e2a37bb5407e6ccf465bd870e89505768497ca50

URL: https://github.com/llvm/llvm-project/commit/e2a37bb5407e6ccf465bd870e89505768497ca50
DIFF: https://github.com/llvm/llvm-project/commit/e2a37bb5407e6ccf465bd870e89505768497ca50.diff

LOG: [mlir] Add alignment option to constant tensor bufferization pass

Reviewed By: bkramer

Differential Revision: https://reviews.llvm.org/D111364

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/include/mlir/Transforms/BufferUtils.h
    mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
    mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index c7e331e856519..58eb7f0b7cb6f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
     RewritePatternSet &patterns);
 
 /// Creates an instance of tensor constant bufferization pass.
-std::unique_ptr<Pass> createTensorConstantBufferizePass();
+std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
 
 /// Creates an instance of the StdExpand pass that legalizes Std
 /// dialect ops to be convertible to LLVM. For example,

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 15f63e952691c..286b685591778 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
   }];
   let constructor = "mlir::createTensorConstantBufferizePass()";
   let dependentDialects = ["memref::MemRefDialect"];
+  let options = [
+    Option<"alignment", "alignment", "unsigned", /*default=*/"0",
+           "Create global memrefs with a specified alignment">,
+  ];
 }
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
index f73c97d75a1e8..5d1b9e5a7d577 100644
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -125,11 +125,13 @@ class GlobalOp;
 // names. Duplicates are avoided.
 class GlobalCreator {
 public:
-  explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
+  GlobalCreator(ModuleOp module, unsigned alignment = 0)
+      : moduleOp(module), alignment(alignment) {}
   memref::GlobalOp getGlobalFor(ConstantOp constantOp);
 
 private:
   ModuleOp moduleOp;
+  unsigned alignment;
   // This could use memref::GlobalOp key but we avoid introducing a new
   // dependence to the memref dialect for this.
   DenseMap<Attribute, Operation *> globals;

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 39dbd8ea49506..df61b1d2f6324 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
   interleave(type.getShape(), os, "x");
   os << "x" << type.getElementType();
 
+  // Add an optional alignment to the global memref.
+  IntegerAttr memrefAlignment =
+      alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
+                    : IntegerAttr();
+
   auto global = globalBuilder.create<memref::GlobalOp>(
       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
       /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
       /*constant=*/true,
-      /*alignment=*/IntegerAttr());
+      /*alignment=*/memrefAlignment);
   symbolTable.insert(global);
   // The symbol table inserts at the end of the module, but globals are a bit
   // nicer if they are at the beginning.
@@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
 }
 
 namespace {
-struct TensorConstantBufferizePass
+class TensorConstantBufferizePass
     : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
+public:
+  explicit TensorConstantBufferizePass(unsigned alignment) {
+    if (alignment)
+      this->alignment = alignment;
+  }
+
   void runOnOperation() override {
     auto module = getOperation();
-    GlobalCreator globals(module);
+    GlobalCreator globals(module, alignment);
 
     auto *context = &getContext();
     BufferizeTypeConverter typeConverter;
@@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
-  return std::make_unique<TensorConstantBufferizePass>();
+std::unique_ptr<Pass>
+mlir::createTensorConstantBufferizePass(unsigned alignment) {
+  return std::make_unique<TensorConstantBufferizePass>(alignment);
 }

diff  --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
index 16a7d7d6d002e..cdaccf3684955 100644
--- a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
@@ -1,9 +1,17 @@
 // RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
 
 // CHECK-LABEL: module {
+
 // We check the debug name too since we put some effort into making that readable.
 // The name isn't load-bearing though.
+
 // CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// CHECK-NOT: alignment
+
+// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// ALIGNED-SAME: {alignment = 64 : i64}
+
 // CHECK: @basic
 func @basic() -> tensor<3x4xf32> {
   // CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>


        


More information about the Mlir-commits mailing list