[Mlir-commits] [mlir] [mlir][memref] Add builder that infers `reinterpret_cast` result type (PR #109432)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 20 07:43:41 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a convenience builder that infers the result type of `memref.reinterpret_cast`.
Note: The result type cannot be removed from all builder overloads, because this op currently also accepts certain combinations of operands/attributes and result type that do not match each other. The op verifier should probably be made stricter, but that is a larger change that would require additional `memref.cast` ops in some places that build `reinterpret_cast` ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/109432.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (+4-2)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
"OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a ReinterpretCastOp and infer the result type.
+ OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+ "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build a ReinterpretCastOp with static entries.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
// that we can call extract_strided_metadata on it.
if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
memref = builder.create<memref::ReinterpretCastOp>(
- loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
- 0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+ loc, memref,
+ /*offset=*/builder.getIndexAttr(0),
+ /*sizes=*/ArrayRef<OpFoldResult>{},
+ /*strides=*/ArrayRef<OpFoldResult>{});
// Use the `memref.extract_strided_metadata` operation to get the base
// memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
b.getDenseI64ArrayAttr(staticStrides));
}
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+ Value source, OpFoldResult offset,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = cast<BaseMemRefType>(source.getType());
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ auto stridedLayout = StridedLayoutAttr::get(
+ b.getContext(), staticOffsets.front(), staticStrides);
+ auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+ stridedLayout, sourceType.getMemorySpace());
+ build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
MemRefType resultType, Value source,
int64_t offset, ArrayRef<int64_t> sizes,
``````````
</details>
https://github.com/llvm/llvm-project/pull/109432
More information about the Mlir-commits mailing list