[Mlir-commits] [mlir] f7d42d5 - [mlir][NVGPU] Verifiers for nvgpu.mma.sync Op

Thomas Raoux llvmlistbot at llvm.org
Wed Jul 13 11:59:07 PDT 2022


Author: Manish Gupta
Date: 2022-07-13T18:57:07Z
New Revision: f7d42d5149ddf564dd87c77b7531ef83ddfad622

URL: https://github.com/llvm/llvm-project/commit/f7d42d5149ddf564dd87c77b7531ef83ddfad622
DIFF: https://github.com/llvm/llvm-project/commit/f7d42d5149ddf564dd87c77b7531ef83ddfad622.diff

LOG: [mlir][NVGPU] Verifiers for nvgpu.mma.sync Op

- Adds verification for the `nvgpu.mma.sync` op
- Adds tests to `mlir/test/Dialect/NVGPU/invalid.mlir`
- The `nvgpu.mma.sync` verifier caught a bug and triggered a failure in the m16n8k4_tf32_f32 variant in `mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir`
     - The output shape of the vector holding the thread-level accumulators was inconsistent; it is fixed in this change (`vector<4x1xf32>` becomes `vector<2x2xf32>`)

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D129400

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/NVGPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 52338c0bc0cf3..ec0c18bd74824 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -81,7 +81,10 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
   }];
 }
 
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+    NoSideEffect,
+    PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>,
+  ]> {
   let description = [{
   The `nvgpu.mma.sync` op represents the distributed form of a collective
   matrix-multiply-and-accumulate (mma) operation that is compatible with
@@ -112,6 +115,8 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
+  // MmaSyncOp::verify() validates the data type and shapes against mmaShape.
+  let hasVerifier = 1;
 }
 
 

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index c31a168cd2103..ac937e0fea0eb 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -88,5 +88,103 @@ LogicalResult DeviceAsyncCopyOp::verify() {
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+  // Fundamental tensor core mma.sync op: for F32 (TF32), BF16, F16, I8, and I4
+  // the fundamental tensor core operation is 8-by-8-by-128b, while F64 is an
+  // exception (8-by-8-by-4). The checks below covering the various mma shapes
+  // and data types are expressed in units of this fundamental operation.
+  constexpr int kThreads = 32; // 32 threads per warp
+  int64_t shapeM = 8;
+  int64_t shapeN = 8;
+  int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+  // Number of elements A, B, and C per thread per fundamental tensor core tile
+  int64_t numElementA;    // set based on data type (32b except F64)
+  int64_t numElementB;    // set based on data type (32b except F64)
+  int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+  // nvgpu.mma.sync vector operands (per thread)
+  auto aVector = getMatrixA().getType().cast<VectorType>();
+  auto bVector = getMatrixB().getType().cast<VectorType>();
+  auto cVector = getMatrixC().getType().cast<VectorType>();
+
+  // All checks below index dimensions 0 and 1 of the operand shapes, so reject
+  // other ranks up front instead of reading past the end of the shape.
+  if (aVector.getRank() != 2 || bVector.getRank() != 2 ||
+      cVector.getRank() != 2)
+    return emitOpError() << "expected matrixA, matrixB, and matrixC to be "
+                            "2-D vectors";
+
+  // vector shapes and element type
+  ArrayRef<int64_t> aShape = aVector.getShape();
+  ArrayRef<int64_t> bShape = bVector.getShape();
+  ArrayRef<int64_t> cShape = cVector.getShape();
+  Type aType = aVector.getElementType();
+
+  // nvgpu.mma.sync shape (per 32 threads or per warp). The attribute is only
+  // known to be an i64 array, so check its length and sign before using it.
+  if (getMmaShape().size() != 3)
+    return emitOpError() << "expected mmaShape to have 3 elements (m, n, k)";
+  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
+  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
+  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+  if (m <= 0 || n <= 0 || k <= 0)
+    return emitOpError() << "expected mmaShape elements to be positive";
+
+  if (aType.isF64()) {
+    // exception to 8-by-8-128b fundamental tensor core tile size
+    shapeK = 4;
+    numElementA = 1;
+    numElementB = 1;
+  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+             aType.isInteger(8) || aType.isInteger(4)) {
+    // 8-by-8-128b fundamental tensor core tile size
+    int operandBitwidth = aType.getIntOrFloatBitWidth();
+    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    numElementA = 32 / operandBitwidth; // 32b wide operand A
+    numElementB = 32 / operandBitwidth; // 32b wide operand B
+  } else {
+    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+                          "supported by nvgpu.mma.sync";
+  }
+
+  // Basic verification: warp-wide element counts of matrices A, B, and C.
+  if (aShape[0] * aShape[1] * kThreads != m * k)
+    return emitOpError() << "expected " << m * k
+                         << " warp-wide matrix A elements";
+  if (bShape[0] * bShape[1] * kThreads != k * n)
+    return emitOpError() << "expected " << k * n
+                         << " warp-wide matrix B elements";
+  if (cShape[0] * cShape[1] * kThreads != m * n)
+    return emitOpError() << "expected " << m * n
+                         << " warp-wide matrix C elements";
+
+  // Extended verification in units of the fundamental tile. Reject shapes that
+  // do not tile evenly so the truncating divisions below stay meaningful.
+  if (m % shapeM != 0 || n % shapeN != 0 || k % shapeK != 0)
+    return emitOpError() << "expected mmaShape to be a multiple of (" << shapeM
+                         << " x " << shapeN << " x " << shapeK << ")";
+  int64_t mTile = m / shapeM;
+  int64_t nTile = n / shapeN;
+  int64_t kTile = k / shapeK;
+
+  // verify shape of aVector
+  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
+    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+                         << " x " << numElementA << ")";
+
+  // verify shape of bVector
+  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
+    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+                         << " x " << numElementB << ")";
+
+  // verify shape of cVector
+  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
+    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+                         << " x " << numElementC << ")";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 0d0f7845e24a9..55b8df621abd9 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -205,7 +205,7 @@ func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
 // -----
 
 // CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {  
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
   // The A, B operand should be bitcast to i32
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32  
@@ -219,17 +219,22 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>  
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
-  return %d : vector<4x1xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+  // Row 0 (above) takes struct elements 0 and 1; row 1 (below) takes 2 and 3.
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>  
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+  // Both row vectors are then inserted into the 2-element result array.
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>   
+  return %d : vector<2x2xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 7a9acb4727d45..6be9cda42ccb3 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -1,4 +1,73 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+// f16: matrix B holds 2*4*32 = 256 warp-wide elements instead of k*n = 128.
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+// f16: matrix C holds 2*4*32 = 256 warp-wide elements instead of m*n = 128.
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>    
+  return %d : vector<2x4xf16>
+}
+// -----
+// f16: matrix A has the right element count (256) but is shaped 2x4, not 4x2.
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+// f32: matrix A holds 4*2*32 = 256 warp-wide elements instead of m*k = 128.
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
+// f32: matrix A has the right element count (128) but is shaped 1x4, not 4x1.
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
+// f64: matrix A holds 1*2*32 = 64 warp-wide elements instead of m*k = 32.
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>    
+  return %d : vector<1x2xf64>
+}
+// -----
+// f64: matrix C has the right element count (64) but is shaped 2x1, not 1x2.
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>    
+  return %d : vector<2x1xf64>
+}
+// -----
+// i8: matrix B holds 4*4*32 = 512 warp-wide elements instead of k*n = 256.
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+// Element types of matrix A (i32) and matrix B (i8) differ.
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
 
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @+1 {{destination memref must have memory space 3}}


        


More information about the Mlir-commits mailing list