[Mlir-commits] [mlir] [mlir][amdgpu] Adds make_dma_gather_base (PR #171857)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Dec 11 09:06:37 PST 2025


================
@@ -755,28 +755,52 @@ LogicalResult TransposeLoadOp::verify() {
 // MakeDmaBaseOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult MakeDmaBaseOp::verify() {
-
-  auto ldsType = cast<MemRefType>(getLds().getType());
-  auto globalType = cast<MemRefType>(getGlobal().getType());
+template <typename BaseOp>
+static LogicalResult verifyBase(BaseOp op) {
+  auto ldsType = cast<MemRefType>(op.getLds().getType());
+  auto globalType = cast<MemRefType>(op.getGlobal().getType());
   if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "lds memref must have workgroup address space attribute.");
   if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
-    return emitOpError(
+    return op.emitOpError(
         "global memref must have global address space attribute.");
 
   Type elementType = ldsType.getElementType();
   unsigned width = elementType.getIntOrFloatBitWidth();
 
   if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
-    return emitOpError(
+    return op.emitOpError(
                "element type must be 1, 2, 4, or 8 bytes long but type was ")
            << width << " bits long.";
+  return success();
+}
+
+LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
+
+//===----------------------------------------------------------------------===//
+// MakeGatherDmaBaseOp
+//===----------------------------------------------------------------------===//
 
+LogicalResult
+TDMGatherBaseType::verify(function_ref<InFlightDiagnostic()> emitError,
+                          Type elementType, Type indexType) {
+  unsigned width = elementType.getIntOrFloatBitWidth();
+  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
+    return emitError()
+           << "element type must be 1, 2, 4, or 8 bytes wide but type "
+           << elementType << " is " << width / 8 << " bytes wide.";
+  MLIRContext *ctx = elementType.getContext();
+  Type i16 = IntegerType::get(ctx, 32);
+  Type i32 = IntegerType::get(ctx, 16);
+  if (!llvm::is_contained<Type>({i16, i32}, indexType))
----------------
amd-eochoalo wrote:

https://github.com/llvm/llvm-project/pull/171857/commits/23a3b45fbf7e9e0d8d6b2458d8a286224184ee68

https://github.com/llvm/llvm-project/pull/171857


More information about the Mlir-commits mailing list