[Mlir-commits] [mlir] [mlir][memref] Add builder that infers `reinterpret_cast` result type (PR #109432)

Matthias Springer llvmlistbot at llvm.org
Fri Sep 20 07:43:07 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/109432

Add a convenience builder that infers the result type of `memref.reinterpret_cast`.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.

>From 4eaa9b53fedc1ab755a2ffcc4e8a7b41f4409fea Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 20 Sep 2024 16:31:01 +0200
Subject: [PATCH] [mlir][memref] Add builder that infers `reinterpret_cast`
 result type

Add a convenience builder that infers the result type of `memref.reinterpret_cast`.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger refactoring that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td        |  4 ++++
 .../IR/BufferDeallocationOpInterface.cpp       |  6 ++++--
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp       | 18 ++++++++++++++++++
 3 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,



More information about the Mlir-commits mailing list