[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Feb 9 13:51:06 PST 2025
================
@@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// FatRawBufferCastOp
+//===----------------------------------------------------------------------===//
+
+/// Convert the type `source` to one with the same shape, strides, and element
+/// type, but placed in the AMDGPU fat raw buffer address space. The offset is
+/// kept as well, unless `resetOffset` is true, in which case it is reset to 0.
+/// If the offset should be reset but the layout of `source` is neither the
+/// identity layout nor a strided layout, this function fails.
+static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
+                                                     bool resetOffset) {
+  MLIRContext *ctx = source.getContext();
+  MemRefType::Builder mb(source);
+  mb.setMemorySpace(
+      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
+  MemRefLayoutAttrInterface layout = source.getLayout();
+  // Identity layouts are left unchanged; any other layout must be a strided
+  // layout for its offset to be rewritten.
+  if (resetOffset && !layout.isIdentity()) {
+    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayout)
+      return failure();
+    mb.setLayout(StridedLayoutAttr::get(ctx, /*offset=*/0,
+                                        stridedLayout.getStrides()));
+  }
+  // MemRefType::Builder materializes the final type via its conversion
+  // operator.
+  return static_cast<MemRefType>(mb);
+}
+
+/// Infers the single result type of `amdgpu.fat_raw_buffer_cast` from the
+/// source operand's memref type and the `resetOffset` flag. Fails if the
+/// source is not a memref, or if its offset cannot be reset.
+LogicalResult FatRawBufferCastOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  Type operandType = adaptor.getSource().getType();
+  auto memrefType = dyn_cast_if_present<MemRefType>(operandType);
+  if (!memrefType)
+    return failure();
+  FailureOr<MemRefType> fatBufferType =
+      getFatRawBufferTypeLike(memrefType, adaptor.getResetOffset());
+  if (failed(fatBufferType))
+    return failure();
+  // Replace any previous contents with exactly one inferred result type.
+  inferredReturnTypes.clear();
+  inferredReturnTypes.push_back(*fatBufferType);
+  return success();
+}
+
+LogicalResult FatRawBufferCastOp::verify() {
+ FailureOr<MemRefType> expectedResultType =
+ getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
+ if (failed(expectedResultType))
+ return emitOpError("source type ")
+ << getSource().getType() << " can't have its offset reset";
+ if (getResult().getType() != *expectedResultType)
+ return emitOpError("expected result type to be ")
+ << *expectedResultType << " but got " << getResult().getType();
----------------
kuhar wrote:
Should we add tests that cover these diagnostics?
https://github.com/llvm/llvm-project/pull/125594
More information about the Mlir-commits
mailing list