[Mlir-commits] [mlir] [MLIR][NVGPU] Improve and Cleanup verifier of TMA OPs (PR #70923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 7 03:05:10 PST 2023
================
@@ -335,34 +335,78 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//
-LogicalResult TmaAsyncLoadOp::verify() {
- // Destination memref
- auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
+ Operation *op, nvgpu::TensorMapDescriptorType descType,
+ std::optional<MemRefType> memrefType = std::nullopt) {
+ MemRefType descMemref = descType.getTensor();
+ // Limitation
+ if (descType.getInterleave() != TensorMapInterleaveKind::INTERLEAVE_NONE)
+ return op->emitError() << "Interleave options are not supported yet.";
+
+ // Address space check for shared memory check
+ if (!NVGPUDialect::hasSharedMemoryAddressSpace(descMemref)) {
+ return op->emitError() << "the tensor map descriptor has incorrect address "
+ "space, it must be shared memory address space.";
+ }
+ // Support only static shape for the time being
+ if (!descMemref.hasStaticShape())
+ return op->emitError() << "the tensor map descriptor must be static shaped";
+
+ // No verification if memref type is not provided
+ if (!memrefType.has_value())
+ return std::nullopt;
+
+ MemRefType dstMemref = memrefType.value();
+
+ // Check element type
+ if (descMemref.getElementType() != dstMemref.getElementType()) {
+ return op->emitError() << "the element type of tensor map descriptor and "
+ "memref must be same";
+ }
+
if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
- return emitError()
- << "The operation stores data to shared memory, but "
- "the destination memref does not have a memory space of "
- << NVGPUDialect::kSharedMemoryAddressSpace;
+ return op->emitError() << "the destination memref has incorrect address "
+ "space, it must be shared memory address space.";
}
- if (getCoordinates().size() > 5) {
- return emitError() << "Maximum 5 coordinates are supported.";
+ if (!dstMemref.hasStaticShape())
+ return op->emitError() << "the destination memref must be static shaped";
----------------
qcolombet wrote:
I would move this check close to its `descMemref` counterpart.
https://github.com/llvm/llvm-project/pull/70923
More information about the Mlir-commits mailing list