[Mlir-commits] [mlir] 334f63e - [mlir][NvGpuToNVVM] Fix missing i4 support for nvgpu.mma.sync

Christopher Bate llvmlistbot at llvm.org
Mon May 23 09:53:22 PDT 2022


Author: Christopher Bate
Date: 2022-05-23T10:52:28-06:00
New Revision: 334f63e7c39f298611d27a2ae27d31e4431be10f

URL: https://github.com/llvm/llvm-project/commit/334f63e7c39f298611d27a2ae27d31e4431be10f
DIFF: https://github.com/llvm/llvm-project/commit/334f63e7c39f298611d27a2ae27d31e4431be10f.diff

LOG: [mlir][NvGpuToNVVM] Fix missing i4 support for nvgpu.mma.sync

This change adds missing support for the i4 data type. Tests are added
to ensure proper lowering of an nvgpu.mma.sync operation targeting the
16x8x64xi4 and 16x8x32xi4 MMA variants in the NVVM dialect.

Differential Revision: https://reviews.llvm.org/D126092

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5152beaa5f61e..ccf85915e49fa 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -145,7 +145,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
   Type f64Ty = rewriter.getF64Type();
   Type f32Ty = rewriter.getF32Type();
   Type i8Ty = rewriter.getI8Type();
+  Type i4Ty = rewriter.getIntegerType(4);
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
 
@@ -156,6 +158,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
     if (arrayTy.getElementType() == i8x4Ty ||
+        arrayTy.getElementType() == i4x8Ty ||
         (arrayTy.getElementType() == f32x1Ty &&
          operandPtxType == NVVM::MMATypes::tf32)) {
       result.push_back(
@@ -281,6 +284,10 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
       ptxTypeA = NVVM::MMATypes::s8;
       ptxTypeB = NVVM::MMATypes::s8;
       overflow = NVVM::MMAIntOverflow::satfinite;
+    } else if (aType.getElementType().isInteger(4)) {
+      ptxTypeA = NVVM::MMATypes::s4;
+      ptxTypeB = NVVM::MMATypes::s4;
+      overflow = NVVM::MMAIntOverflow::satfinite;
     } else if (aType.getElementType().isF16()) {
       ptxTypeA = NVVM::MMATypes::f16;
       ptxTypeB = NVVM::MMATypes::f16;

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 249cb48944c88..7bd02b7413117 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -102,6 +102,54 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
 
 // -----
 
+// CHECK-LABEL: @m16n8k32_i4
+func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32    
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<1 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32  
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @m16n8k64_i4
+func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<8xi4>>
+  // CHECK: llvm.bitcast [[el]] : vector<8xi4> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
+  // CHECK-SAME: shape = {k = 64 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @m8n8k4_f64
 func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
   // CHECK: llvm.extractvalue


        


More information about the Mlir-commits mailing list