[Mlir-commits] [mlir] 9879807 - [mlir][NvGpu] Fix nvgpu.mma.sync lowering to NVVM for f32, tf32 types

Christopher Bate llvmlistbot at llvm.org
Sun May 8 20:58:56 PDT 2022


Author: Christopher Bate
Date: 2022-05-08T21:49:42-06:00
New Revision: 9879807393d3f502d3cac468c5f6451db872aa5f

URL: https://github.com/llvm/llvm-project/commit/9879807393d3f502d3cac468c5f6451db872aa5f
DIFF: https://github.com/llvm/llvm-project/commit/9879807393d3f502d3cac468c5f6451db872aa5f.diff

LOG: [mlir][NvGpu] Fix nvgpu.mma.sync lowering to NVVM for f32, tf32 types

Adds missing logic in the lowering from NvGPU to NVVM to support the fp32
type (in the accumulator operand) and the tf32 type (in the multiplicand
operands). Fixes logic in one of the helper functions for converting the
result of an mma.sync operation with multiple 8x256-bit output tiles, which
is the case for f32 outputs.

Differential Revision: https://reviews.llvm.org/D124533

Added: 
    

Modified: 
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a304e49e58387..3ea4330c64082 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -25,6 +25,8 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
   auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
   Type f64Ty = Float64Type::get(ctx);
   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f32Ty = Float32Type::get(ctx);
+  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
   if (a.getElementType() == f16x2Ty) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -37,6 +39,15 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
   if (a.getElementType() == f64x2Ty) {
     return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
   }
+  if (a.getElementType() == f32x2Ty) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx,
+        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
+  }
+  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+    return LLVM::LLVMStructType::getLiteral(
+        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
+  }
   return vectorResultType;
 }
 
@@ -52,10 +63,13 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
   auto structType = intrinsicResultType.dyn_cast<LLVM::LLVMStructType>();
   auto arrayType = resultType.dyn_cast<LLVM::LLVMArrayType>();
   Type i32Ty = rewriter.getI32Type();
+  Type f32Ty = rewriter.getF32Type();
   Type f64Ty = rewriter.getF64Type();
   Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
   Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
   Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
 
   auto makeConst = [&](int32_t index) -> Value {
     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -65,21 +79,31 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
   if (arrayType) {
     SmallVector<Value, 4> elements;
 
-    if (arrayType.getElementType() == f16x2Ty) {
+    // The intrinsic returns 32-bit wide elements in a form which can be
+    // directly bitcasted and inserted into the result vector.
+    if (arrayType.getElementType() == f16x2Ty ||
+        arrayType.getElementType() == f32x1Ty) {
       for (unsigned i = 0; i < structType.getBody().size(); i++) {
-        elements.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        Value el = rewriter.create<LLVM::ExtractValueOp>(
             loc, structType.getBody()[i], intrinsicResult,
-            rewriter.getI64ArrayAttr(i)));
+            rewriter.getI64ArrayAttr(i));
+        el = rewriter.createOrFold<LLVM::BitcastOp>(
+            loc, arrayType.getElementType(), el);
+        elements.push_back(el);
       }
     }
 
-    // The intrinsic returns i32 and f64 values as individual scalars. We need
-    // to extract them from the struct and pack them into vectors.
+    // The intrinsic returns i32, f64, and f32 values as individual scalars,
+    // even when the result is notionally a 64-bit wide element (e.g. f32x2). We
+    // need to extract them from the struct and pack them into the 64-bit wide
+    // rows of the vector result.
     if (arrayType.getElementType() == i32x2Ty ||
-        arrayType.getElementType() == f64x2Ty) {
-      Value vec =
-          rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
+        arrayType.getElementType() == f64x2Ty ||
+        arrayType.getElementType() == f32x2Ty) {
+
       for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
+        Value vec =
+            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
         Value x1 = rewriter.create<LLVM::ExtractValueOp>(
             loc, structType.getBody()[i * 2], intrinsicResult,
             rewriter.getI64ArrayAttr(i * 2));
@@ -90,8 +114,8 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                                      x1, makeConst(0));
         vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                      x2, makeConst(1));
+        elements.push_back(vec);
       }
-      elements.push_back(vec);
     }
 
     // Create the final vectorized result.
@@ -113,12 +137,15 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
 /// scalars of certain types. This function helps unpack the `vector` arguments
 /// and cast them to the types expected by `nvvm.mma.sync`.
 static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
-                                              Location loc, Value operand) {
+                                              Location loc, Value operand,
+                                              NVVM::MMATypes operandPtxType) {
   SmallVector<Value> result;
   Type i32Ty = rewriter.getI32Type();
   Type f64Ty = rewriter.getF64Type();
+  Type f32Ty = rewriter.getF32Type();
   Type i8Ty = rewriter.getI8Type();
   Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
+  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = operand.getType().cast<LLVM::LLVMArrayType>();
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -127,18 +154,21 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
 
     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
-    if (arrayTy.getElementType() == i8x4Ty) {
+    if (arrayTy.getElementType() == i8x4Ty ||
+        (arrayTy.getElementType() == f32x1Ty &&
+         operandPtxType == NVVM::MMATypes::tf32)) {
       result.push_back(
           rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
       continue;
     }
 
-    // For some element types (i32, f64), we need to unpack the inner
+    // For some element types (i32, f32, f64), we need to unpack the inner
     // vector/array type as well because the intrinsic expects individual
     // scalars to be provided.
     VectorType innerArrayTy = arrayTy.getElementType().dyn_cast<VectorType>();
     if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
-                         innerArrayTy.getElementType() == f64Ty)) {
+                         innerArrayTy.getElementType() == f64Ty ||
+                         innerArrayTy.getElementType() == f32Ty)) {
       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
            idx < innerSize; idx++) {
         result.push_back(rewriter.create<LLVM::ExtractElementOp>(
@@ -229,37 +259,47 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     auto aType = op.matrixA().getType().cast<VectorType>();
+    auto cType = op.matrixC().getType().cast<VectorType>();
 
     int64_t m = op.mmaShape()[0].cast<IntegerAttr>().getInt();
     int64_t n = op.mmaShape()[1].cast<IntegerAttr>().getInt();
     int64_t k = op.mmaShape()[2].cast<IntegerAttr>().getInt();
     std::array<int64_t, 3> gemmShape{m, n, k};
 
-    SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.matrixA());
-    SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.matrixB());
-    SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.matrixC());
-
     NVVM::MMATypes ptxTypeA;
     NVVM::MMATypes ptxTypeB;
+    Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
+        cType.getElementType(), /*isAccumulator=*/true);
+    if (!ptxTypeC) {
+      return op->emitError(
+          "could not infer the PTX type for the accumulator/result");
+    }
+
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
     if (aType.getElementType().isInteger(8)) {
       ptxTypeA = NVVM::MMATypes::s8;
       ptxTypeB = NVVM::MMATypes::s8;
       overflow = NVVM::MMAIntOverflow::satfinite;
-
     } else if (aType.getElementType().isF16()) {
       ptxTypeA = NVVM::MMATypes::f16;
       ptxTypeB = NVVM::MMATypes::f16;
     } else if (aType.getElementType().isF64()) {
       ptxTypeA = NVVM::MMATypes::f64;
       ptxTypeB = NVVM::MMATypes::f64;
+    } else if (aType.getElementType().isF32()) {
+      ptxTypeA = NVVM::MMATypes::tf32;
+      ptxTypeB = NVVM::MMATypes::tf32;
     } else {
       return op->emitError("could not deduce operand PTX types");
     }
 
+    SmallVector<Value> matA =
+        unpackOperandVector(rewriter, loc, adaptor.matrixA(), ptxTypeA);
+    SmallVector<Value> matB =
+        unpackOperandVector(rewriter, loc, adaptor.matrixB(), ptxTypeB);
+    SmallVector<Value> matC =
+        unpackOperandVector(rewriter, loc, adaptor.matrixC(), *ptxTypeC);
+
     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
         typeConverter->convertType(op->getResultTypes()[0]));

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
index dfc31f1da8560..345270ab5635d 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir
@@ -24,6 +24,34 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
 
 // -----
 
+// Same as above but with an fp32 accumulation type.
+
+// CHECK-LABEL: @m16n8k16_fp16_fp32
+func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // We just need to check the mma instruction and the manipulation of the result (each pair of f32 scalars is packed into its own vector<2xf32> row).
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n  = 8 : i32}
+  // CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>    
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>  
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+  
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>      
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @m16n8k8_fp16
 func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
@@ -125,3 +153,33 @@ func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   // CHECK: llvm.insertvalue    
   return %a : vector<1x2xf16>
 }
+
+// -----
+
+// CHECK-LABEL: @m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {  
+  // Each vector<1xf32> element of the A and B (tf32) operands should be bitcast to i32.
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32  
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
+
+  // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}]
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<tf32>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
+  // CHECK-SAME: shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}
+  // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>  
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
+  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
+  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
+  return %d : vector<4x1xf32>
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list