[Mlir-commits] [mlir] [mlir][nvgpu] Remove strict verifiers on `warpgroup.generate.descriptor` (PR #69935)

Guray Ozen llvmlistbot at llvm.org
Mon Oct 23 08:21:00 PDT 2023


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/69935

This PR relaxes some verifier rules that I found overly restrictive: it drops the requirement that the memref type match the tensor map type, and the requirement that the memref be 2D. These rules can be worked around, for example by generating additional subviews, but that only bloats the IR.
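A minimal, hypothetical C++ model may make the effect clearer. `ShapedTy`, `Swizzle`, `VerifyResult`, and `verifyDescriptor` are illustrative names and are not MLIR APIs. The model tracks only shapes, ignoring element type, layout, and memory space. It covers only the checks visible in the (truncated) hunk of the patch, so the real verifier may perform further checks after the swizzle test.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins; NOT the real MLIR MemRefType or
// TensorMapSwizzleKind classes.
constexpr int64_t kDynamic = -1;  // marks a dynamic dimension

enum class Swizzle { None, Swizzle32B, Swizzle64B, Swizzle128B };

struct ShapedTy {
  std::vector<int64_t> shape;
  bool hasStaticShape() const {
    for (int64_t dim : shape)
      if (dim == kDynamic) return false;
    return true;
  }
};

enum class VerifyResult { Ok, DynamicShape, UnsupportedSwizzle };

// Models only the checks that remain after the patch.
VerifyResult verifyDescriptor(const ShapedTy &memref,
                              const ShapedTy &tensorMap, Swizzle swizzle) {
  // Dropped by the patch: requiring memref == tensorMap and rank == 2.
  if (!memref.hasStaticShape() || !tensorMap.hasStaticShape())
    return VerifyResult::DynamicShape;
  if (swizzle != Swizzle::Swizzle128B)
    return VerifyResult::UnsupportedSwizzle;
  return VerifyResult::Ok;
}
```

In this model, a `128x64` memref paired with a `64x64` tensor map, or a rank-3 memref, now verifies. The removed type-equality and rank checks would have rejected both cases. Dynamic shapes and non-128B swizzles are still rejected.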

From b7d18a288f446513fb173e0356d887c8f12de7fb Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 23 Oct 2023 17:20:26 +0200
Subject: [PATCH] [mlir][nvgpu] Remove strict verifiers on
 `warpgroup.generate.descriptor`

This PR relaxes some verifier rules that I found overly restrictive: it drops the requirement that the memref type match the tensor map type, and the requirement that the memref be 2D. These rules can be worked around, for example by generating additional subviews, but that only bloats the IR.
---
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index f5b02fe1b515591..15eeba2839479d8 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -375,15 +375,9 @@ LogicalResult WarpgroupGenerateDescriptorOp::verify() {
   MemRefType memrefType = getTensor().getType();
   MemRefType tensorMapType = getTensorMap().getType().getTensor();
 
-  if (memrefType != tensorMapType)
-    return emitError() << "memref and tensor map type mismatch";
-
   if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
     return emitError() << "supports only static shapes";
 
-  if (memrefType.getRank() != 2)
-    return emitError() << "supports only 2d memref is supported for now";
-
   if (getTensorMap().getType().getSwizzle() !=
       TensorMapSwizzleKind::SWIZZLE_128B) {
     return emitError() << "supports only "