[Mlir-commits] [mlir] 2491867 - [mlir][nvgpu] Improve verifier of `ldmatrix` (#77807)

llvmlistbot at llvm.org
Thu Jan 11 23:57:16 PST 2024


Author: Guray Ozen
Date: 2024-01-12T08:57:12+01:00
New Revision: 249186701d5c74a36d5a1c8ccb5de2deac42944a

URL: https://github.com/llvm/llvm-project/commit/249186701d5c74a36d5a1c8ccb5de2deac42944a
DIFF: https://github.com/llvm/llvm-project/commit/249186701d5c74a36d5a1c8ccb5de2deac42944a.diff

LOG: [mlir][nvgpu] Improve verifier of `ldmatrix` (#77807)

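This PR improves the verifier of the `nvgpu.ldmatrix` op: results that are
not 2-D vectors are now rejected before `resShape[1]` is read, so the
`nvgpu-to-nvvm` lowering no longer crashes on such IR.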
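For illustration, the order in which the verifier's shape checks run can be sketched as standalone C++. This is a minimal sketch, not the MLIR code: `verifyLdMatrixShape` is a hypothetical name, a `std::vector<int64_t>` stands in for the result vector's shape, a returned string stands in for `emitError()`, and `numElementsPer32b` is assumed to be `32 / elementBitWidth`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the shape checks in LdMatrixOp::verify().
// `resShape` plays the role of the result VectorType's shape, and a returned
// string plays the role of emitError(). Assumption: numElementsPer32b is
// 32 / elementBitWidth.
std::optional<std::string> verifyLdMatrixShape(const std::vector<int64_t> &resShape,
                                               unsigned elementBitWidth,
                                               bool isTranspose) {
  assert(elementBitWidth != 0 && 32 % elementBitWidth == 0 &&
         "element type must pack evenly into 32 bits");
  if (isTranspose && elementBitWidth != 16)
    return std::string("nvgpu.ldmatrix transpose works only at 16b granularity");
  // The check added by this commit: reject non-2-D results *before* reading
  // resShape[1]. For a 1-D shape such as {4}, index 1 is out of bounds.
  if (resShape.size() != 2)
    return std::string("results must be 2 dimensional vector");
  const int64_t numElementsPer32b = 32 / elementBitWidth;
  if (resShape[1] != numElementsPer32b)
    return "expected vector register shape[1] = " + std::to_string(numElementsPer32b);
  return std::nullopt; // shape is acceptable
}
```

The point is the ordering: the rank check must run before any code reads `resShape[1]`, because on a 1-D shape that access is out of bounds.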

Modified: 
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Dialect/NVGPU/invalid.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..b0a4ed1cc2697c 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -321,6 +321,9 @@ LogicalResult LdMatrixOp::verify() {
   if (isTranspose && !(elementBitWidth == 16))
     return emitError()
            << "nvgpu.ldmatrix transpose works only at 16b granularity";
+  if (resShape.size() != 2) {
+    return emitError() << "results must be 2 dimensional vector";
+  }
   if (!(resShape[1] == numElementsPer32b))
     return emitError() << "expected vector register shape[1] = "
                        << numElementsPer32b;

diff  --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 3bffbc78569793..e1949fcfad7ad6 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -40,6 +40,14 @@ func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) ->  vector<4x1xf
 }
 // -----
 
+func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) ->  vector<4x1xf32> {
+  %c0  = arith.constant 0 : index
+  // expected-error @+1 {{results must be 2 dimensional vector}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4xf32>
+  return %a : vector<4xf32>
+}
+// -----
+
 func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
   // expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}}

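The tests use result types such as `vector<4x1xf32>` and `vector<4x2xf16>` together with `numTiles = 4`. Assuming each tile contributes one 32-bit register per thread, packed with `32 / elementBitWidth` elements (an assumption about the `ldmatrix` register layout that this diff does not state), the expected 2-D result shape can be computed as in the sketch below; `expectedLdMatrixResultShape` and `matchesExpected` are hypothetical helpers.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper. Assumption: per thread, ldmatrix yields one row per
// tile (numTiles 32-bit registers), each row holding 32 / elementBitWidth
// elements.
std::array<int64_t, 2> expectedLdMatrixResultShape(int64_t numTiles,
                                                   unsigned elementBitWidth) {
  assert(elementBitWidth != 0 && 32 % elementBitWidth == 0 &&
         "element type must pack evenly into 32 bits");
  return {numTiles, static_cast<int64_t>(32 / elementBitWidth)};
}

// True only if `shape` is 2-D and equals the expected shape. A 1-D shape
// fails the size check before either element is compared.
bool matchesExpected(const std::vector<int64_t> &shape, int64_t numTiles,
                     unsigned elementBitWidth) {
  const auto expected = expectedLdMatrixResultShape(numTiles, elementBitWidth);
  return shape.size() == 2 && shape[0] == expected[0] && shape[1] == expected[1];
}
```

Under that assumption, the 1-D `vector<4xf32>` in the new test can never match any expected shape, which is the case the new rank check reports.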