[Mlir-commits] [mlir] [mlir][nvgpu]add dim check test to nvgpu.mma op. (PR #122864)

lonely eagle llvmlistbot at llvm.org
Mon Jan 13 22:49:52 PST 2025


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/122864

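As the title says: when an operand of `nvgpu.mma.sync` is a one-dimensional vector, mlir-opt crashes instead of reporting an error. This patch adds rank checks to the op's verifier (`verifyMmaSyncOp`) that require matrixA, matrixB, and matrixC to each be 2-dimensional vectors. It also adds matching negative tests to invalid.mlir.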

>From 19afe0ccfdbcc57f55c3f79faa17e7bf5fef5551 Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 14 Jan 2025 14:45:27 +0800
Subject: [PATCH] add dim check test to nvgpu.mma op.

---
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 12 +++++++++++
 mlir/test/Dialect/NVGPU/invalid.mlir       | 24 ++++++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index de9bbcbace6924..a027350e8a5f70 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -203,6 +203,18 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
   // Basic verification
   //
 
+  if (aShape.size() != 2) {
+    return op->emitError() << "matrixA must be 2 dimensional vector";
+  }
+
+  if (bShape.size() != 2) {
+    return op->emitError() << "matrixB must be 2 dimensional vector";
+  }
+
+  if (cShape.size() != 2) {
+    return op->emitError() << "matrixC must be 2 dimensional vector";
+  }
+
   auto [m, n, k] = mmaShape;
 
   // verify warp-wide size for vector a
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index f7db1140794e54..b5bfbe9ff27b79 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -354,3 +354,27 @@ func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
   // expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
   %out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
 }
+
+// -----
+
+func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixA must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixB must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<4xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixC must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}



More information about the Mlir-commits mailing list