[Mlir-commits] [mlir] e2a37bb - [mlir] Add alignment option to constant tensor bufferization pass
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Oct 8 03:17:26 PDT 2021
Author: Eugene Zhulenev
Date: 2021-10-08T03:17:20-07:00
New Revision: e2a37bb5407e6ccf465bd870e89505768497ca50
URL: https://github.com/llvm/llvm-project/commit/e2a37bb5407e6ccf465bd870e89505768497ca50
DIFF: https://github.com/llvm/llvm-project/commit/e2a37bb5407e6ccf465bd870e89505768497ca50.diff
LOG: [mlir] Add alignment option to constant tensor bufferization pass
Reviewed By: bkramer
Differential Revision: https://reviews.llvm.org/D111364
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
mlir/include/mlir/Transforms/BufferUtils.h
mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index c7e331e856519..58eb7f0b7cb6f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -39,7 +39,7 @@ void populateTensorConstantBufferizePatterns(
RewritePatternSet &patterns);
/// Creates an instance of tensor constant bufferization pass.
-std::unique_ptr<Pass> createTensorConstantBufferizePass();
+std::unique_ptr<Pass> createTensorConstantBufferizePass(unsigned alignment = 0);
/// Creates an instance of the StdExpand pass that legalizes Std
/// dialect ops to be convertible to LLVM. For example,
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 15f63e952691c..286b685591778 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -62,6 +62,10 @@ def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
}];
let constructor = "mlir::createTensorConstantBufferizePass()";
let dependentDialects = ["memref::MemRefDialect"];
+ let options = [
+ Option<"alignment", "alignment", "unsigned", /*default=*/"0",
+ "Create global memrefs with a specified alignment">,
+ ];
}
#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h
index f73c97d75a1e8..5d1b9e5a7d577 100644
--- a/mlir/include/mlir/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Transforms/BufferUtils.h
@@ -125,11 +125,13 @@ class GlobalOp;
// names. Duplicates are avoided.
class GlobalCreator {
public:
- explicit GlobalCreator(ModuleOp module) : moduleOp(module) {}
+ GlobalCreator(ModuleOp module, unsigned alignment = 0)
+ : moduleOp(module), alignment(alignment) {}
memref::GlobalOp getGlobalFor(ConstantOp constantOp);
private:
ModuleOp moduleOp;
+ unsigned alignment;
// This could use memref::GlobalOp key but we avoid introducing a new
// dependence to the memref dialect for this.
DenseMap<Attribute, Operation *> globals;
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 39dbd8ea49506..df61b1d2f6324 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -43,13 +43,18 @@ memref::GlobalOp GlobalCreator::getGlobalFor(ConstantOp constantOp) {
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
+ // Add an optional alignment to the global memref.
+ IntegerAttr memrefAlignment =
+ alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
+ : IntegerAttr();
+
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/typeConverter.convertType(type).cast<MemRefType>(),
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
/*constant=*/true,
- /*alignment=*/IntegerAttr());
+ /*alignment=*/memrefAlignment);
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
@@ -90,11 +95,17 @@ void mlir::populateTensorConstantBufferizePatterns(
}
namespace {
-struct TensorConstantBufferizePass
+class TensorConstantBufferizePass
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
+public:
+ explicit TensorConstantBufferizePass(unsigned alignment) {
+ if (alignment)
+ this->alignment = alignment;
+ }
+
void runOnOperation() override {
auto module = getOperation();
- GlobalCreator globals(module);
+ GlobalCreator globals(module, alignment);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
@@ -111,6 +122,7 @@ struct TensorConstantBufferizePass
};
} // namespace
-std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
- return std::make_unique<TensorConstantBufferizePass>();
+std::unique_ptr<Pass>
+mlir::createTensorConstantBufferizePass(unsigned alignment) {
+ return std::make_unique<TensorConstantBufferizePass>(alignment);
}
diff --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
index 16a7d7d6d002e..cdaccf3684955 100644
--- a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
@@ -1,9 +1,17 @@
// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
// CHECK-LABEL: module {
+
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
+
// CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// CHECK-NOT: alignment
+
+// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// ALIGNED-SAME: {alignment = 64 : i64}
+
// CHECK: @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>
More information about the Mlir-commits mailing list