[Mlir-commits] [mlir] [mlir][memref] Add builder that infers `reinterpret_cast` result type (PR #109432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 20 07:43:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a convenience builder that infers the result type of `memref.reinterpret_cast`.

Note: It is not possible to remove the result type from all builder overloads because this op currently also allows certain operand/attribute + result type combinations that do not match. The op verifier should probably be made stricter, but that's a larger change that requires additional `memref.cast` ops in some places that build `reinterpret_cast` ops.

---
Full diff: https://github.com/llvm/llvm-project/pull/109432.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+4) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp (+4-2) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,

``````````

</details>


https://github.com/llvm/llvm-project/pull/109432


More information about the Mlir-commits mailing list