[Mlir-commits] [mlir] [MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs (PR #70923)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 7 03:05:10 PST 2023


================
@@ -335,34 +335,78 @@ LogicalResult LdMatrixOp::verify() {
 // NVGPU_TmaAsyncLoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult TmaAsyncLoadOp::verify() {
-  // Destination memref
-  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+    Operation *op, nvgpu::TensorMapDescriptorType descType,
+    std::optional<MemRefType> memrefType = std::nullopt) {
+  MemRefType descMemref = descType.getTensor();
+  // Limitation
+  if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+    return op->emitError() << "Interleave options are not supported yet.";
+
+  // Address space check for shared memory check
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+    return op->emitError() << "the tensor map descriptor has incorrect address "
+                              "space, it must be shared memory address space.";
+  }
+  // Support only static shape for the time being
+  if (!descMemref.hasStaticShape())
+    return op->emitError() << "the tensor map descriptor must be static shaped";
+
+  // No verification if memref type is not provided
+  if (!memrefType.has_value())
+    return std::nullopt;
+
+  MemRefType dstMemref = memrefType.value();
+
+  // Check element type
+  if (descMemref.getElementType() != dstMemref.getElementType()) {
+    return op->emitError() << "the element type of tensor map descriptor and "
+                              "memref must be same";
+  }
+
   if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
-    return emitError()
-           << "The operation stores data to shared memory, but "
-              "the destination memref does not have a memory space of "
-           << NVGPUDialect::kSharedMemoryAddressSpace;
+    return op->emitError() << "the destination memref has incorrect address "
+                              "space, it must be shared memory address space.";
   }
-  if (getCoordinates().size() > 5) {
-    return emitError() << "Maximum 5 coordinates are supported.";
+  if (!dstMemref.hasStaticShape())
+    return op->emitError() << "the destination memref must be static shaped";
+
+  if (dstMemref.getRank() != descMemref.getRank()) {
+    return op->emitError() << "the shape of tensor map descriptor and "
+                              "memref must have same rank";
+  }
+  if (!descMemref.getShape().equals(dstMemref.getShape())) {
+    return op->emitError() << "memref and tensor map shapes mismatch "
+                           << descMemref << " != " << dstMemref;
   }
-  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
-    return emitError() << "Destination memref rank is "
-                       << size_t(dstMemref.getRank()) << " but there are  "
-                       << getCoordinates().size()
-                       << " coordinates. They must match.";
+
+  return std::nullopt;
+}
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  std::optional<InFlightDiagnostic> error = verifyTmaDescriptorWithMemref(
+      *this, getTensorMapDescriptor().getType(), getDst().getType());
+  if (error.has_value())
+    return error.value();
+
+  if (getCoordinates().size() > kMaxTMATensorDimension) {
+    return emitError() << "Maximum " << kMaxTMATensorDimension
+                       << " coordinates are supported.";
   }
----------------
qcolombet wrote:

Shouldn't we check that the number of coordinates is equal to the rank of the memref?

I remember you said you wanted to relax that, but if we do, I think we have to update the related op's documentation to explain the semantics.

https://github.com/llvm/llvm-project/pull/70923


More information about the Mlir-commits mailing list