[Mlir-commits] [mlir] [mlir][nvgpu] Improve verifier of `ldmatrix` (PR #77807)
Guray Ozen
llvmlistbot at llvm.org
Thu Jan 11 09:50:07 PST 2024
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/77807
This PR improves the verifier of the `nvgpu.ldmatrix` op to reject results that are not 2-D vectors, so that the `nvgpu-to-nvvm` lowering does not crash.
From 8db09fb0b21a30c348dc2f2765acca5f66e74db0 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 11 Jan 2024 18:47:16 +0100
Subject: [PATCH] [mlir][nvgpu] Improve verifier of `ldmatrix`
This PR improves the verifier of the `nvgpu.ldmatrix` op so that the `nvgpu-to-nvvm` lowering does not crash.
---
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 3 +++
mlir/test/Dialect/NVGPU/invalid.mlir | 8 ++++++++
2 files changed, 11 insertions(+)
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c9756ae8fc11ce..b0a4ed1cc2697c 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -321,6 +321,9 @@ LogicalResult LdMatrixOp::verify() {
   if (isTranspose && !(elementBitWidth == 16))
     return emitError()
            << "nvgpu.ldmatrix transpose works only at 16b granularity";
+  if (resShape.size() != 2) {
+    return emitError() << "results must be 2 dimensional vector";
+  }
   if (!(resShape[1] == numElementsPer32b))
     return emitError() << "expected vector register shape[1] = "
                        << numElementsPer32b;
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 3bffbc78569793..e1949fcfad7ad6 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -40,6 +40,14 @@ func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf
 }
 // -----
+func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{results must be 2 dimensional vector}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4xf32>
+  return %a : vector<4xf32>
+}
+// -----
+
 func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> {
   %c0 = arith.constant 0 : index
   // expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}}