[Mlir-commits] [mlir] 0259f92 - [mlir][memref] Add builder that infers `reinterpret_cast` result type (#109432)
llvmlistbot at llvm.org
Wed Sep 25 00:33:19 PDT 2024
Author: Matthias Springer
Date: 2024-09-25T09:33:15+02:00
New Revision: 0259f92711599c45d229fb12f6f51915fffac6bd
URL: https://github.com/llvm/llvm-project/commit/0259f92711599c45d229fb12f6f51915fffac6bd
DIFF: https://github.com/llvm/llvm-project/commit/0259f92711599c45d229fb12f6f51915fffac6bd.diff
LOG: [mlir][memref] Add builder that infers `reinterpret_cast` result type (#109432)
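Add a convenience builder that infers the result type of
`memref.reinterpret_cast` from the source memref (its element type and
memory space) and the given offset, sizes, and strides. Use it in
`DeallocationState::getMemrefsAndConditionsToDeallocate`.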
Note: The result type cannot be dropped from every builder overload.
The op currently also accepts some combinations of operands/attributes
and result type that do not match, i.e. where the result type differs
from the one that would be inferred. The op verifier should probably be
made stricter. That is a larger change, however, because it requires
inserting additional `memref.cast` ops at some of the places that build
`reinterpret_cast` ops.
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a ReinterpretCastOp and infer the result type.
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+ "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ReinterpretCastOp with static entries.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
memref = builder.create<memref::ReinterpretCastOp>(
- loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
- 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+ loc, memref,
+ /*offset=*/builder.getIndexAttr(0),
+ /*sizes=*/ArrayRef<OpFoldResult>{},
+ /*strides=*/ArrayRef<OpFoldResult>{});
// Use the `memref.extract_strided_metadata` operation to get the base
// memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
b.getDenseI64ArrayAttr(staticStrides));
}
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+ Value source, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ auto stridedLayout = StridedLayoutAttr::get(
+ b.getContext(), staticOffsets.front(), staticStrides);
+ auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+ stridedLayout, sourceType.getMemorySpace());
+ build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
More information about the Mlir-commits
mailing list