[Mlir-commits] [mlir] 8bb5ca5 - [mlir] Support bufferization of arith.constant to memref.global with memory space
Maya Amrami
llvmlistbot at llvm.org
Mon Mar 20 08:13:48 PDT 2023
Author: Maya Amrami
Date: 2023-03-20T17:09:51+02:00
New Revision: 8bb5ca58327ec5d0788b1546844b06b1118c5cb5
URL: https://github.com/llvm/llvm-project/commit/8bb5ca58327ec5d0788b1546844b06b1118c5cb5
DIFF: https://github.com/llvm/llvm-project/commit/8bb5ca58327ec5d0788b1546844b06b1118c5cb5.diff
LOG: [mlir] Support bufferization of arith.constant to memref.global with memory space
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D146381
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 6c521acd0e146..85e9c47ad5302 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -125,7 +125,8 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
- uint64_t alignment);
+ uint64_t alignment,
+ Attribute memorySpace = {});
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 8408aad6e3fc1..9602d530cf826 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -26,10 +27,11 @@ struct ConstantOpInterface
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- // TODO: Implement memory space for this op. E.g., by adding a memory_space
- // attribute to ConstantOp.
- if (options.defaultMemorySpace != Attribute())
- return op->emitError("memory space not implemented yet");
+ Attribute memorySpace;
+ if (options.defaultMemorySpace.has_value())
+ memorySpace = *options.defaultMemorySpace;
+ else
+ return constantOp->emitError("could not infer memory space");
// Only ranked tensors are supported.
if (!constantOp.getType().isa<RankedTensorType>())
@@ -43,7 +45,7 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, options.bufferAlignment);
+ getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 38d69194be1e8..b9776e2fb2095 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -147,7 +147,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+ Attribute memorySpace) {
auto type = constantOp.getType().cast<RankedTensorType>();
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -184,10 +185,13 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
: IntegerAttr();
BufferizeTypeConverter typeConverter;
+ auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+ if (memorySpace)
+ memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
- /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
+ /*type=*/memrefType,
/*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
/*constant=*/true,
/*alignment=*/memrefAlignment);
diff --git a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
index 315da00a00d78..deda8bb74b323 100644
--- a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
+++ b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
@@ -13,8 +13,8 @@ func.func @inconsistent_memory_space_arith_select(%c: i1) -> tensor<10xf32> {
// -----
-func.func @constant_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
- // expected-error @+2 {{memory space not implemented yet}}
+func.func @unknown_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
+ // expected-error @+2 {{could not infer memory space}}
// expected-error @+1 {{failed to bufferize op}}
%cst = arith.constant dense<[5, 1000, 20]> : tensor<3xi32>
%0 = tensor.insert %v into %cst[%idx] : tensor<3xi32>
More information about the Mlir-commits mailing list