[Mlir-commits] [mlir] 0259f92 - [mlir][memref] Add builder that infers `reinterpret_cast` result type (#109432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 25 00:33:19 PDT 2024


Author: Matthias Springer
Date: 2024-09-25T09:33:15+02:00
New Revision: 0259f92711599c45d229fb12f6f51915fffac6bd

URL: https://github.com/llvm/llvm-project/commit/0259f92711599c45d229fb12f6f51915fffac6bd
DIFF: https://github.com/llvm/llvm-project/commit/0259f92711599c45d229fb12f6f51915fffac6bd.diff

LOG: [mlir][memref] Add builder that infers `reinterpret_cast` result type (#109432)

Add a convenience builder that infers the result type of
`memref.reinterpret_cast`.

Note: The result type cannot be dropped from every builder overload,
because this op currently also accepts certain combinations of
operands/attributes and result type that do not match each other (i.e.,
where the given result type is not the type the new builder would
infer). The op verifier should probably be made stricter, but that is a
larger change: it requires inserting additional `memref.cast` ops in
some places that build `reinterpret_cast` ops.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2ff9d612a5efa7..c50df6ccd9aa56 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1407,6 +1407,10 @@ def MemRef_ReinterpretCastOp
       "OpFoldResult":$offset, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a ReinterpretCastOp and infer the result type.
+    OpBuilder<(ins "Value":$source, "OpFoldResult":$offset,
+      "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a ReinterpretCastOp with static entries.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "int64_t":$offset, "ArrayRef<int64_t>":$sizes,

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index b197786c320548..51dfd84d9ac601 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -197,8 +197,10 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate(
     // that we can call extract_strided_metadata on it.
     if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType()))
       memref = builder.create<memref::ReinterpretCastOp>(
-          loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref,
-          0, SmallVector<int64_t>{}, SmallVector<int64_t>{});
+          loc, memref,
+          /*offset=*/builder.getIndexAttr(0),
+          /*sizes=*/ArrayRef<OpFoldResult>{},
+          /*strides=*/ArrayRef<OpFoldResult>{});
 
     // Use the `memref.extract_strided_metadata` operation to get the base
     // memref. This is needed because the same MemRef that was produced by the

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9c021d3613f1c8..75b9729e63648c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1832,6 +1832,24 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
         b.getDenseI64ArrayAttr(staticStrides));
 }
 
+void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                              Value source, OpFoldResult offset,
+                              ArrayRef<OpFoldResult> sizes,
+                              ArrayRef<OpFoldResult> strides,
+                              ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  auto stridedLayout = StridedLayoutAttr::get(
+      b.getContext(), staticOffsets.front(), staticStrides);
+  auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
+                                    stridedLayout, sourceType.getMemorySpace());
+  build(b, result, resultType, source, offset, sizes, strides, attrs);
+}
+
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
                               MemRefType resultType, Value source,
                               int64_t offset, ArrayRef<int64_t> sizes,


        


More information about the Mlir-commits mailing list