[Mlir-commits] [mlir] 14d79af - [mlir][NVGPU] nvgpu.mmasync on F32 through TF32

Thomas Raoux llvmlistbot at llvm.org
Mon Aug 1 16:24:41 PDT 2022


Author: Manish Gupta
Date: 2022-08-01T23:23:27Z
New Revision: 14d79afeae63d78de9483f750fafaba13c7ae2dc

URL: https://github.com/llvm/llvm-project/commit/14d79afeae63d78de9483f750fafaba13c7ae2dc
DIFF: https://github.com/llvm/llvm-project/commit/14d79afeae63d78de9483f750fafaba13c7ae2dc.diff

LOG: [mlir][NVGPU] nvgpu.mmasync on F32 through TF32

Adds an optional attribute to support tensor cores on the F32 datatype by lowering to `mma.sync` with TF32 operands. Since TF32 is not a native datatype in LLVM, we are adding `tf32Enabled` as an attribute to allow the IR to be aware of the `MmaSyncOp` datatype. Additionally, this patch adds placeholders for an nvgpu-to-nvgpu transformation targeting higher-precision tf32x3.

For mma.sync on f32 input using tensor cores there are two possibilities:
(a) tf32   (1 `mma.sync` per warp-level matrix-multiply-accumulate)
(b) tf32x3 (3 `mma.sync` per warp-level matrix-multiply-accumulate)

Typically, tf32 tensor core acceleration comes at a cost of accuracy from missing precision bits. While f32 has 23 precision bits, tf32 has only 10 precision bits. tf32x3 aims to recover the precision bits by splitting each operand into two tf32 values and issuing three `mma.sync` tensor core operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130294

Added: 
    mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
    mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
    mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
    mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
    mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
    mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/NVGPU/invalid.mlir
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 737752bc2be84..d0dd5a63021dc 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -110,11 +110,22 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
     (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   ```
   }];
-  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
-                       AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
+  let arguments = (ins AnyVector:$matrixA, 
+                       AnyVector:$matrixB,
+                       AnyVector:$matrixC, 
+                       I64ArrayAttr:$mmaShape,
+                       OptionalAttr<UnitAttr>:$tf32Enabled
+                       );
 
   let results = (outs AnyVector:$res);
 
+  let builders = [
+    OpBuilder<(ins "Value":$matrixA, 
+                   "Value":$matrixB, 
+                   "Value":$matrixC, 
+                   "ArrayAttr":$mmaShape)>
+  ];
+
   let assemblyFormat = [{
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)

diff  --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
index abc910da1c267..32888f1abf81d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h
@@ -19,6 +19,10 @@
 namespace mlir {
 namespace nvgpu {
 
+///
+/// Passes
+///
+
 /// Optimizes vectorized accesses to a shared memory buffer specified by
 /// memrefValue. This transformation assumes the following:
 /// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
@@ -41,6 +45,29 @@ namespace nvgpu {
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
+///
+/// Rewrites patterns
+///
+
+//===----------------------------------------------------------------------===//
+// NVGPU transformation options exposed as auxiliary structs.
+//===----------------------------------------------------------------------===//
+/// Enum to control the lowering of f32 `nvgpu.mma.sync` to tensor cores.
+enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 };
+
+/// Collect patterns to convert mma.sync on f32 input and rewrite
+/// to use tensor cores with user provided level of accuracy:
+/// (a) tf32   (1 mma.sync per warp-level matrix-multiply-accumulate)
+/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
+/// Typically, tf32 tensor core acceleration comes at a cost
+/// of accuracy from missing precision bits. While f32 has 23 precision
+/// bits, tf32 has only 10 precision bits. tf32x3 aims to recover the
+/// precision bits by splitting each operand into two tf32 values
+/// and issuing three mma.sync tensor core operations.
+void populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns,
+    nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);
+
 } // namespace nvgpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 682a0d403c550..41f0877e071e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -275,10 +275,14 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
     NVVM::MMATypes ptxTypeB;
     Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
         cType.getElementType(), /*isAccumulator=*/true);
-    if (!ptxTypeC) {
+    if (!ptxTypeC)
       return op->emitError(
           "could not infer the PTX type for the accumulator/result");
-    }
+
+    // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
+    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
+    if (aType.getElementType().isF32() && !tf32Enabled)
+      return failure();
 
     Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
     if (aType.getElementType().isInteger(8)) {

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 6691193146f2f..26758825bb01a 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -687,8 +687,8 @@ convertContractOpToMmaSync(vector::ContractionOp op,
   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
-  Value matmul = b.create<nvgpu::MmaSyncOp>(
-      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
+  Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
+                                            b.getI64ArrayAttr({m, n, k}));
   valueMapping[op.getResult()] = matmul;
   return success();
 }

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 1ced01179dd82..8580a84d0aee0 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -91,6 +91,12 @@ LogicalResult DeviceAsyncCopyOp::verify() {
 //===----------------------------------------------------------------------===//
 // NVGPU_MmaSyncOp
 //===----------------------------------------------------------------------===//
+void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
+                      ::mlir::OperationState &odsState, Value matrixA,
+                      Value matrixB, Value matrixC, ArrayAttr mmaShape) {
+  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+        mmaShape, /*tf32Enabled=*/UnitAttr());
+}
 
 LogicalResult MmaSyncOp::verify() {
 
@@ -122,6 +128,9 @@ LogicalResult MmaSyncOp::verify() {
   // vector element type
   Type aType = aVector.getElementType();
 
+  // tensor float32 (TF32) enabled
+  bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
+
   // nvgpu.mma.sync shape (per 32 threads or per warp)
   int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
   int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
@@ -163,6 +172,10 @@ LogicalResult MmaSyncOp::verify() {
     return emitOpError() << "expected " << m * n
                          << " warp-wide matrix C elements";
 
+  // verify tf32 tensor cores are enabled for only F32 datatype
+  if (tf32Enabled && !(aType.isF32()))
+    return emitOpError() << "expected tf32 tensor cores only for F32 operands";
+
   //
   // Extended verification
   //

diff  --git a/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
index 831f39620b02a..afe7d167fc7d9 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRNVGPUTransforms
-  OptimizeSharedMemory.cpp  
+  OptimizeSharedMemory.cpp
+  MmaSyncTF32Transform.cpp  
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU

diff  --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
new file mode 100644
index 0000000000000..4ef93b30978a4
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -0,0 +1,73 @@
+//===- MmaSyncTF32Transform.cpp - MLIR NVGPU pass implementation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma sync
+// operations on f32 input datatype
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+/// Tags f32 `nvgpu.mma.sync` ops with `tf32Enabled` for tensor-core lowering.
+struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
+  using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
+
+  MmaSyncF32ToTF32Pattern(MLIRContext *context,
+                          nvgpu::MmaSyncF32Lowering precision)
+      : OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benefit=*/1),
+        precision(precision) {}
+
+  LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
+                                PatternRewriter &rewrite) const override {
+    // Only untagged f32 ops match; the op verifier rejects `tf32Enabled` on
+    // any other element type, so tagging e.g. f16 ops would create invalid IR.
+    Type aType = op.getMatrixA().getType().cast<VectorType>().getElementType();
+    if (op->hasAttr(op.getTf32EnabledAttrName()) || !aType.isF32())
+      return failure();
+
+    if (precision == MmaSyncF32Lowering::Unkown)
+      return emitError(op->getLoc(), "MmaSync F32-to-TF32 cannot be lowered "
+                                     "with unknown precision level");
+    if (precision == MmaSyncF32Lowering::TF32x3)
+      return emitError(op->getLoc(), "TF32x3 is not supported at the moment "
+                                     "for nvgpu.mma.sync on f32 datatype");
+
+    // Report the in-place attribute change through the rewriter.
+    rewrite.updateRootInPlace(
+        op, [&] { op.setTf32EnabledAttr(rewrite.getUnitAttr()); });
+    return success();
+  }
+
+private:
+  /// Precision for F32 Tensor Cores (TF32 or TF32x3)
+  nvgpu::MmaSyncF32Lowering precision;
+};
+
+} // namespace
+
+void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
+    RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
+  // A single pattern is registered; it dispatches on `precision` internally.
+  patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
+}

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 55b8df621abd9..aa71a26f81069 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -219,7 +219,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
   // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
   // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
   // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>

diff  --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 5f1894faeb709..7fc84109bb392 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -76,6 +76,13 @@ func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1:
 }
 // -----
 
+func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+
 func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    

diff  --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
new file mode 100644
index 0000000000000..a8c72262f101b
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: m16n8k8_tf32
+func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // CHECK: nvgpu.mma.sync
+  // CHECK-SAME: tf32Enabled
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----

diff  --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
new file mode 100644
index 0000000000000..523ba245ab436
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32x3.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32x3" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: m16n8k4_tf32
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  return %d : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: m16n8k8_tf32
+func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{TF32x3 is not supported at the moment for nvgpu.mma.sync on f32 datatype}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 7c8d1a709d137..6bc8635757e44 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(GPU)
 add_subdirectory(Linalg)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)

diff  --git a/mlir/test/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..2fe031d88f167
--- /dev/null
+++ b/mlir/test/lib/Dialect/NVGPU/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRNVGPUTestPasses
+  TestNVGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRAffineDialect
+  MLIRAnalysis
+  MLIRFuncDialect
+  MLIRGPUOps
+  MLIRLLVMDialect
+  MLIRMemRefDialect
+  MLIRNVGPUDialect
+  MLIRNVGPUTransforms
+  MLIRPass
+  MLIRSCFDialect
+  MLIRTransformUtils
+  )
+  

diff  --git a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp
new file mode 100644
index 0000000000000..74a15ba273d86
--- /dev/null
+++ b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp
@@ -0,0 +1,76 @@
+//===- TestNVGPUTransforms.cpp - Test NVGPU transforms and lowerings ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::nvgpu;
+
+namespace {
+
+/// Test pass that applies the mma.sync f32-to-tf32 patterns to a function.
+struct TestMmaSyncF32ToTF32Patterns
+    : public PassWrapper<TestMmaSyncF32ToTF32Patterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMmaSyncF32ToTF32Patterns)
+  StringRef getArgument() const final {
+    return "test-nvgpu-mmasync-f32-to-tf32-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to convert mma.sync on f32 with tf32 precision";
+  }
+  TestMmaSyncF32ToTF32Patterns() = default;
+  TestMmaSyncF32ToTF32Patterns(const TestMmaSyncF32ToTF32Patterns &pass)
+      : PassWrapper(pass) {}
+
+  Option<std::string> precision{
+      *this, "precision",
+      llvm::cl::desc(
+          "Target nvgpu.mma.sync on f32 input with tf32 or tf32x3 precision"),
+      llvm::cl::init("tf32")};
+
+  void runOnOperation() override {
+    // Decode the option here: a member initializer would run at construction,
+    // before the pass options are parsed, and only ever see the default.
+    MmaSyncF32Lowering tf32Precision =
+        llvm::StringSwitch<MmaSyncF32Lowering>(precision)
+            .Case("tf32", MmaSyncF32Lowering::TF32)
+            .Case("tf32x3", MmaSyncF32Lowering::TF32x3)
+            .Default(MmaSyncF32Lowering::Unkown);
+    RewritePatternSet patterns(&getContext());
+    populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestNvgpuLowerings() {
+  PassRegistration<TestMmaSyncF32ToTF32Patterns>(); // called from mlir-opt
+}
+
+} // namespace test
+} // namespace mlir
\ No newline at end of file

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 87acd7361b73b..59036b1f467bc 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -20,6 +20,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLinalgTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
+    MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
     MLIRSPIRVTestPasses

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3d48ec2987ac6..63c028c595806 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -113,6 +113,7 @@ void registerTestTensorTransforms();
 void registerTestTilingInterface();
 void registerTestTransformDialectInterpreterPass();
 void registerTestVectorLowerings();
+void registerTestNvgpuLowerings();
 } // namespace test
 } // namespace mlir
 
@@ -208,6 +209,7 @@ void registerTestPasses() {
   mlir::test::registerTestTilingInterface();
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
+  mlir::test::registerTestNvgpuLowerings();
 }
 #endif
 


        


More information about the Mlir-commits mailing list