[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)

Jakub Kuderski llvmlistbot at llvm.org
Sun Feb 9 13:51:06 PST 2025


================
@@ -59,6 +59,59 @@ LogicalResult PackedStochRoundFp8Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// FatRawBufferCastOp
+//===----------------------------------------------------------------------===//
+
+/// Convert the type `source` to one with the same sizes and strides - and
+/// offset, unless `resetOffset` is true, in which case the offset is reset to
+/// 0. If the offset should be reset but the layout of `source` isn't either the
+/// identity layout or a strided layout, this function fails.
+static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
+                                                     bool resetOffset) {
+  MLIRContext *ctx = source.getContext();
+  MemRefType::Builder mb(source);
+  mb.setMemorySpace(
+      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
+  MemRefLayoutAttrInterface layout = source.getLayout();
+  if (resetOffset && !layout.isIdentity()) {
+    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayout)
+      return failure();
+    mb.setLayout(StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()));
+  }
+  return (MemRefType)(mb);
+}
+
+LogicalResult FatRawBufferCastOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Adaptor adaptor(operands, attributes, properties, regions);
+  auto sourceType =
+      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
+  if (!sourceType)
+    return failure();
+  FailureOr<MemRefType> resultType =
+      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
+  if (failed(resultType))
+    return failure();
+  inferredReturnTypes = SmallVector<Type>{*resultType};
+  return success();
+}
+
+LogicalResult FatRawBufferCastOp::verify() {
+  FailureOr<MemRefType> expectedResultType =
+      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
+  if (failed(expectedResultType))
+    return emitOpError("source type ")
+           << getSource().getType() << " can't have its offset reset";
+  if (getResult().getType() != *expectedResultType)
+    return emitOpError("expected result type to be ")
+           << *expectedResultType << " but got " << getResult().getType();
----------------
kuhar wrote:

Should we add tests that cover these diagnostics?

https://github.com/llvm/llvm-project/pull/125594


More information about the Mlir-commits mailing list