[Mlir-commits] [mlir] 13f4e88 - Revert "Revert "[mlir][Transform] Add support for mma.sync m16n8k16 f16 rewrite." and "[mlir][Transform] Introduce nvgpu transform extensions""

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jun 27 23:50:13 PDT 2023


Author: Nicolas Vasilache
Date: 2023-06-28T06:50:05Z
New Revision: 13f4e889c55288180dfef494c24a123356517d92

URL: https://github.com/llvm/llvm-project/commit/13f4e889c55288180dfef494c24a123356517d92
DIFF: https://github.com/llvm/llvm-project/commit/13f4e889c55288180dfef494c24a123356517d92.diff

LOG: Revert "Revert "[mlir][Transform] Add support for mma.sync m16n8k16 f16 rewrite." and "[mlir][Transform] Introduce nvgpu transform extensions""

This reverts commit 6506692fe619ef8a1f7c6ea829d9a9eceb31622d.

Differential Revision: https://reviews.llvm.org/D153845
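
For reference, the new op is driven from a transform dialect schedule. A minimal
schedule, taken from the tests added by this patch, looks like:

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !transform.any_op):
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
      : (!transform.any_op) -> ()
  }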

Added: 
    mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
    mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
    mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
    mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
    mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
    mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir

Modified: 
    mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/NVGPU/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
index d48de9df4b71b..49a07c997a30c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)

diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..d75ae3dd5d017
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS NVGPUTransformOps.td)
+mlir_tablegen(NVGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(NVGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRNVGPUTransformOpsIncGen)

diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
new file mode 100644
index 0000000000000..0c7b9d865aa24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
@@ -0,0 +1,43 @@
+//===- NVGPUTransformOps.h - NVGPU transform ops ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace nvgpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace nvgpu
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// NVGPU Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H

diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
new file mode 100644
index 0000000000000..168a445b62ccf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -0,0 +1,51 @@
+//===- NVGPUTransformOps.td - NVGPU transform ops ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVGPU_TRANSFORM_OPS
+#define NVGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+def RewriteMatmulAsMmaSyncOp :
+  Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformEachOpTrait, 
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Rewrite a matmul operation on memref to an mma.sync operation on vectors.
+
+    Memory copies with the required access patterns are automatically inserted.
+    Operations that do not have a 1-1 mapping to mma.sync operations are left
+    unchanged.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp linalgOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // NVGPU_TRANSFORM_OPS

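
As an abbreviated before/after sketch of the rewrite described above (the full
expected IR is in the transform-matmul-to-nvvm.mlir test added below; SSA names
here are illustrative): a 16x8x4 tf32 matmul on memrefs such as

  linalg.matmul ins(%A, %B : memref<16x4xf32>, memref<4x8xf32>)
                outs(%C : memref<16x8xf32>)

is replaced by per-lane memref.load / vector.insert sequences that build the
operand fragments, a single

  nvgpu.mma.sync(%lhs, %rhs, %res) {mmaShape = [16, 8, 4], tf32Enabled}
    : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>

and per-lane memref.store operations that write the accumulator back.
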
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index db15dff136cd1..106d5af5cfb7d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -137,6 +138,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   gpu::registerTransformDialectExtension(registry);
   linalg::registerTransformDialectExtension(registry);
   memref::registerTransformDialectExtension(registry);
+  nvgpu::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
   transform::registerPDLExtension(registry);

diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
index 7117520599fa6..63b4d8b99f53f 100644
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)

diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..973e5268389fd
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_dialect_library(MLIRNVGPUTransformOps
+  NVGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU/TransformOps
+
+  DEPENDS
+  MLIRNVGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithDialect
+  MLIRIR
+  MLIRLinalgDialect
+  MLIRNVGPUDialect
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRVectorTransforms
+  )

diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
new file mode 100644
index 0000000000000..b08b105d91e19
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -0,0 +1,488 @@
+//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::nvgpu;
+using namespace mlir::transform;
+
+#define DEBUG_TYPE "nvgpu-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+/// Helper struct to encode a pair of row/column indexings in the form of
+/// affine expressions.
+struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
+  RowColIndexing(AffineExpr row, AffineExpr col)
+      : std::pair<AffineExpr, AffineExpr>(row, col) {}
+
+  AffineExpr row() const { return first; };
+  AffineExpr col() const { return second; };
+
+  void print(llvm::raw_ostream &os) const {
+    os << "- indexing: " << first << ", " << second;
+  }
+};
+
+/// Helper struct to provide a simple mapping from matmul operations to the
+/// corresponding mma.sync operation. This is constrained to the case where the
+/// matmul matches the mma.sync operation 1-1.
+struct MmaSyncBuilder {
+  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
+      : b(b), loc(loc), laneId(laneId) {}
+
+  using IndexCalculator =
+      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
+
+  /// Create the mma.sync operation corresponding to `linalgOp` along with all
+  /// the supporting load/store and vector operations.
+  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
+
+private:
+  struct MmaSyncInfo {
+    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
+    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
+        vectorShapes;
+    SmallVector<int64_t> mmaShape;
+    bool tf32Enabled;
+  };
+
+  /// Return the specific index calculators for the given op shape and element
+  /// types, or failure if the combination is not supported. This is the
+  /// top-level switch that should just be Tablegen'd in the future.
+  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
+                                             TypeRange elementalTypes);
+
+  //===--------------------------------------------------------------------===//
+  // Instruction-specific row, column indexing expression builders.
+  // These should all be declaratively specified via Tablegen in the future.
+  // The Tablegen specification should be as straightforward as possible to
+  // only model the existing size and type combinations.
+  //===--------------------------------------------------------------------===//
+  //
+  // TODO: Tablegen all this.
+  //===--------------------------------------------------------------------===//
+  // m16n8k4 tf32 case.
+  //===--------------------------------------------------------------------===//
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =      groupID            for a0
+  ///            groupID + 8        for a1
+  /// col =  threadIDInGroup
+  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    return {RowColIndexing{groupID, threadIDInGroup},
+            RowColIndexing{groupID + 8, threadIDInGroup}};
+  }
+
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =  threadIDInGroup
+  /// col =  groupID
+  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    return {RowColIndexing{threadIDInGroup, groupID}};
+  }
+
+  /// From the NVIDIA doc:
+  /// groupID          = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  /// row =      groupID                            for c0 and c1
+  ///          groupID + 8                          for c2 and c3
+  /// col =  (threadIDInGroup * 2) + (i & 0x1)    for ci   where i = {0,..,3}
+  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
+            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
+            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
+            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
+  }
+
+  //===--------------------------------------------------------------------===//
+  // m16n8k16 f16 case.
+  //===--------------------------------------------------------------------===//
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  ///
+  /// row =      groupID            for ai where  0 <= i < 2 || 4 <= i < 6
+  ///           groupID + 8         Otherwise
+  ///
+  /// col =  (threadIDInGroup * 2) + (i & 0x1)          for ai where i <  4
+  ///        (threadIDInGroup * 2) + (i & 0x1) + 8      for ai where i >= 4
+  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    // clang-format off
+    return {
+      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
+      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
+      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
+      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
+    };
+    // clang-format on
+  }
+
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  ///
+  /// row =  (threadIDInGroup * 2) + (i & 0x1)           for bi where i <  2
+  ///        (threadIDInGroup * 2) + (i & 0x1) + 8       for bi where i >= 2
+  ///
+  /// col = groupID
+  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    // clang-format off
+    return {
+      RowColIndexing{threadIDInGroup * 2 + 0, groupID},        // i == 0
+      RowColIndexing{threadIDInGroup * 2 + 1, groupID},        // i == 1
+      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},    // i == 2
+      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}     // i == 3
+    };
+    // clang-format on
+  }
+
+  /// From the NVIDIA doc:
+  /// groupID           = %laneid >> 2
+  /// threadIDInGroup = %laneid % 4
+  ///
+  /// row =      groupID                               for ci where i <  2
+  ///          groupID + 8                             for ci where i >= 2
+  ///
+  /// col =  (threadIDInGroup * 2) + (i & 0x1)      for ci where i = {0,..,3}
+  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
+    auto dim = getAffineDimExpr(0, ctx);
+    AffineExpr groupID = dim.floorDiv(4);
+    AffineExpr threadIDInGroup = dim % 4;
+    // clang-format off
+    return {
+      RowColIndexing{groupID, threadIDInGroup * 2 + 0},        // i == 0
+      RowColIndexing{groupID, threadIDInGroup * 2 + 1},        // i == 1
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},    // i == 2
+      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}     // i == 3
+    };
+    // clang-format on
+  }
+
+  //===--------------------------------------------------------------------===//
+  /// Helper functions to create customizable load and store operations. The
+  /// specific shapes of each MMA instruction are passed via the
+  /// IndexCalculator callback.
+  //===--------------------------------------------------------------------===//
+  /// Build a list of memref.load operations indexed at `(row, col)` indices
+  /// that make sense for a particular MMA instruction and specified via the
+  /// IndexCalculator callback.
+  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+                                      OpFoldResult laneId, Value memref,
+                                      IndexCalculator indexFn);
+
+  /// Perform a distributed load of a vector operand of `vectorShape` for a
+  /// particular MMA instruction whose `(row, col)` indices are specified via
+  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
+  /// data that makes sense for the particular MMA operation.
+  /// The `vectorShape` matches existing NVGPU dialect op specification but
+  /// could also be flattened in the future if needed for simplification.
+  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+                                      OpFoldResult laneId, Value memref,
+                                      IndexCalculator indexFn,
+                                      ArrayRef<int64_t> vectorShape);
+
+  /// Build a list of memref.store operations indexed at `(row, col)` indices
+  /// that make sense for a particular MMA instruction and specified via the
+  /// IndexCalculator callback.
+  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+                                             ValueRange toStore,
+                                             OpFoldResult laneId, Value memref,
+                                             IndexCalculator indexFn);
+
+  /// Perform a distributed store of a vector operand of `vectorShape` for a
+  /// particular MMA instruction whose `(row, col)` indices are specified via
+  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
+  /// data that makes sense for the particular MMA operation.
+  /// The `vectorShape` matches existing NVGPU dialect op specification but
+  /// could also be flattened in the future if needed for simplification.
+  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
+
+  OpBuilder &b;
+  Location loc;
+  OpFoldResult laneId;
+};
+
+//===--------------------------------------------------------------------===//
+/// Helper functions to create customizable load and store operations. The
+/// specific shapes of each MMA instruction are passed via the
+/// IndexCalculator callback.
+//===--------------------------------------------------------------------===//
+
+template <typename ApplyFn, typename ReduceFn>
+static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
+                                           ReduceFn reduceFn) {
+  VectorType vectorType = vector.getType().cast<VectorType>();
+  auto vectorShape = vectorType.getShape();
+  auto strides = computeStrides(vectorShape);
+  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
+    auto indices = delinearize(idx, strides);
+    reduceFn(applyFn(vector, idx, indices), idx, indices);
+  }
+}
+
+SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+                                                    OpFoldResult laneId,
+                                                    Value memref,
+                                                    IndexCalculator indexFn) {
+  auto aff = [&](AffineExpr e) {
+    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
+  };
+  SmallVector<Value> res;
+  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
+  for (auto indexing : indexings) {
+    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
+    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
+    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
+    res.push_back(load);
+  }
+  return res;
+}
+
+Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
+    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+
+  Type elementType = getElementTypeOrSelf(memref.getType());
+  auto vt = VectorType::get(vectorShape, elementType);
+  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
+  foreachIndividualVectorElement(
+      res,
+      /*applyFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        return loads[linearIdx];
+      },
+      /*reduceFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        res = b.create<vector::InsertOp>(loc, v, res, indices);
+      });
+
+  return res;
+}
+
+SmallVector<Operation *>
+MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+                                  ValueRange toStore, OpFoldResult laneId,
+                                  Value memref, IndexCalculator indexFn) {
+  auto aff = [&](AffineExpr e) {
+    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
+  };
+  SmallVector<Operation *> res;
+  for (auto [indexing, val] :
+       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
+    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
+    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
+    Operation *store =
+        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
+    res.push_back(store);
+  }
+  return res;
+}
+
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  SmallVector<Value> toStore;
+  toStore.reserve(32);
+  foreachIndividualVectorElement(
+      vectorToStore,
+      /*applyFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
+      },
+      /*reduceFn=*/
+      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
+        toStore.push_back(v);
+      });
+  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+}
+
+static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+                  SmallVector<int64_t>>
+makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+                 ArrayRef<int64_t> res) {
+  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
+  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
+  SmallVector<int64_t> vres{res.begin(), res.end()};
+  return std::make_tuple(vlhs, vrhs, vres);
+}
+
+FailureOr<MmaSyncBuilder::MmaSyncInfo>
+MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
+                                    TypeRange elementalTypes) {
+  // TODO: Tablegen all this.
+  Type f16 = b.getF16Type();
+  Type f32 = b.getF32Type();
+  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
+      elementalTypes == TypeRange{f32, f32, f32}) {
+    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Res),
+                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
+                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
+                       /*tf32Enabled=*/true};
+  }
+  // This is the version with f16 accumulation.
+  // TODO: version with f32 accumulation.
+  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
+      elementalTypes == TypeRange{f16, f16, f16}) {
+    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
+                                       &MmaSyncBuilder::m16n8k16f16Rhs,
+                                       &MmaSyncBuilder::m16n8k16f16Res),
+                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
+                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
+                       /*tf32Enabled=*/false};
+  }
+  return failure();
+}
+
+FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
+  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
+  assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+         "expected lhs to be a 2D memref");
+  assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+         "expected rhs to be a 2D memref");
+  assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
+         "expected res to be a 2D memref");
+
+  int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
+  int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
+  int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
+  Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
+  Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
+  Type resType = getElementTypeOrSelf(resMemref.getType());
+
+  FailureOr<MmaSyncInfo> maybeInfo =
+      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
+  if (failed(maybeInfo))
+    return failure();
+
+  MmaSyncInfo info = *maybeInfo;
+  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
+  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
+  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+                                            lhsIndexFn, lhsShape);
+  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+                                            rhsIndexFn, rhsShape);
+  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+                                            resIndexFn, resShape);
+  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
+                                   info.tf32Enabled);
+  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+                                 resShape);
+  return res.getDefiningOp();
+}
+
+DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
+    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  bool fail = true;
+  // TODO: more robust detection of matmulOp, with transposes etc.
+  if (auto matmulOp = isa<linalg::MatmulOp>(linalgOp.getOperation())) {
+    Location loc = linalgOp.getLoc();
+    // TODO: more robust computation of laneId, for now assume a single warp.
+    Value laneId = rewriter.create<gpu::ThreadIdOp>(
+        loc, rewriter.getIndexType(), gpu::Dimension::x);
+    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
+      fail = false;
+  }
+
+  if (fail) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "unsupported target op: " << linalgOp;
+    diag.attachNote(linalgOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  rewriter.eraseOp(linalgOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class NVGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          NVGPUTransformDialectExtension> {
+public:
+  NVGPUTransformDialectExtension() {
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<affine::AffineDialect>();
+    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+
+void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<NVGPUTransformDialectExtension>();
+}

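
As a quick sanity check of the indexing comments above, take the m16n8k4 tf32
case for lane 6: groupID = 6 floordiv 4 = 1 and threadIDInGroup = 6 mod 4 = 2,
so that lane loads A[1, 2] and A[9, 2] for the LHS, B[2, 1] for the RHS, and
the accumulator elements C[1, 4], C[1, 5], C[9, 4], C[9, 5], which is the
per-lane fragment layout the comments quote from the NVIDIA documentation.
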
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
new file mode 100644
index 0000000000000..241f218c79c57
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK: #[[$div4:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #[[$mod4:.*]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK: #[[$div4p8:.*]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK: #[[$map3:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+// CHECK: #[[$map4:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 1)>
+
+// CHECK-LABEL: func.func @matmul_16x8x4xf32_global
+func.func @matmul_16x8x4xf32_global(
+    %A: memref<16x4xf32>, %B: memref<4x8xf32>, %C: memref<16x8xf32>) {
+// CHECK-SAME:                                        %[[VAL_0:.*]]: memref<16x4xf32>,
+// CHECK-SAME:                                        %[[VAL_1:.*]]: memref<4x8xf32>,
+// CHECK-SAME:                                        %[[VAL_2:.*]]: memref<16x8xf32>) {
+
+// CHECK:           %[[TIDX:.*]] = gpu.thread_id  x
+// CHECK:           %[[VAL_4:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_5:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_6:.*]] = memref.load %[[VAL_0]][%[[VAL_4]], %[[VAL_5]]] : memref<16x4xf32>
+// CHECK:           %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
+// CHECK:           %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
+//
+// CHECK:           %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
+// CHECK:           %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
+//
+// CHECK:           %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_19:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_2]][%[[VAL_18]], %[[VAL_19]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_21:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_22:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_2]][%[[VAL_21]], %[[VAL_22]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_24:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_25:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_2]][%[[VAL_24]], %[[VAL_25]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
+// CHECK:           %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK:           %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK:           %[[RES:.*]] = vector.insert %[[VAL_29]], %[[VAL_33]] [1, 1] : f32 into vector<2x2xf32>
+//
+// CHECK:           %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+//
+// CHECK:           %[[VAL_36:.*]] = vector.extract %[[VAL_35]][0, 0] : vector<2x2xf32>
+// CHECK:           %[[VAL_37:.*]] = vector.extract %[[VAL_35]][0, 1] : vector<2x2xf32>
+// CHECK:           %[[VAL_38:.*]] = vector.extract %[[VAL_35]][1, 0] : vector<2x2xf32>
+// CHECK:           %[[VAL_39:.*]] = vector.extract %[[VAL_35]][1, 1] : vector<2x2xf32>
+// CHECK:           %[[VAL_40:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_41:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK:           memref.store %[[VAL_36]], %[[VAL_2]][%[[VAL_40]], %[[VAL_41]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_42:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_43:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK:           memref.store %[[VAL_37]], %[[VAL_2]][%[[VAL_42]], %[[VAL_43]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_44:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_45:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK:           memref.store %[[VAL_38]], %[[VAL_2]][%[[VAL_44]], %[[VAL_45]]] : memref<16x8xf32>
+// CHECK:           %[[VAL_46:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK:           %[[VAL_47:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK:           memref.store %[[VAL_39]], %[[VAL_2]][%[[VAL_46]], %[[VAL_47]]] : memref<16x8xf32>
+// CHECK:           return
+// CHECK:         }
+  linalg.matmul ins(%A, %B: memref<16x4xf32>, memref<4x8xf32>)
+            outs(%C: memref<16x8xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul 
+    : (!transform.any_op) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: func.func @matmul_16x8x16xf16_global
+func.func @matmul_16x8x16xf16_global(
+    %A: memref<16x16xf16>, %B: memref<16x8xf16>, %C: memref<16x8xf16>) {
+
+  // CHECK-COUNT-8: memref.load {{.*}} : memref<16x16xf16>
+  // CHECK-COUNT-8: vector.insert {{.*}} : f16 into vector<4x2xf16> 
+  // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+  // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16> 
+  // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+  // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
+  //
+  //         CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 16]} 
+  //    CHECK-SAME:   : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  //
+  // CHECK-COUNT-4: vector.extract %{{.*}} : vector<2x2xf16>
+  // CHECK-COUNT-4: memref.store %{{.*}} : memref<16x8xf16>
+  linalg.matmul ins(%A, %B: memref<16x16xf16>, memref<16x8xf16>)
+            outs(%C: memref<16x8xf16>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul 
+    : (!transform.any_op) -> ()
+}

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
new file mode 100644
index 0000000000000..6788ccea3a222
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm80_tests:
+    config.unsupported = True

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
new file mode 100644
index 0000000000000..593e4c317e0e4
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
@@ -0,0 +1,239 @@
+// RUN: mlir-opt %s \
+// RUN:  -test-transform-dialect-interpreter \
+// RUN:  -test-transform-dialect-erase-schedule \
+// RUN:  -gpu-kernel-outlining \
+// RUN:  -convert-scf-to-cf \
+// RUN:  -convert-vector-to-llvm \
+// RUN:  -convert-math-to-llvm \
+// RUN:  -expand-strided-metadata \
+// RUN:  -lower-affine \
+// RUN:  -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:  -convert-arith-to-llvm \
+// RUN:  -finalize-memref-to-llvm \
+// RUN:  -convert-func-to-llvm \
+// RUN:  -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:  -gpu-to-llvm \
+// RUN:  -convert-func-to-llvm \
+// RUN:  -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+!lhs_memref_type = memref<16x16xf16>
+!rhs_memref_type = memref<16x8xf16>
+!res_memref_type = memref<16x8xf16>
+
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f16 {
+  %r = arith.index_cast %ridx : index to i32
+  %c = arith.index_cast %cidx : index to i32
+  %strideC = arith.index_cast %strideCidx : index to i32
+  %2 = arith.muli %r, %strideC : i32
+  %3 = arith.addi %c, %2 : i32
+  %4 = arith.sitofp %3 : i32 to f16
+  %factor = arith.constant 64.0 : f16
+  %5 = arith.divf %4, %factor : f16
+  return %5: f16
+}
+
+func.func @print_lhs_as_memref_32(%lhs: !lhs_memref_type) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %lhs, %c0 : !lhs_memref_type
+  %N = memref.dim %lhs, %c1 : !lhs_memref_type
+  %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+  scf.for %m = %c0 to %M step %c1 {
+    scf.for %n = %c0 to %N step %c1 {
+      %f16 = memref.load %lhs[%m, %n] : !lhs_memref_type
+      %f32 = arith.extf %f16 : f16 to f32
+      memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+    }
+  }
+  %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+  memref.dealloc %tmp_alloc : memref<?x?xf32>
+  return
+}
+
+func.func @print_rhs_as_memref_32(%rhs: !rhs_memref_type) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %rhs, %c0 : !rhs_memref_type
+  %N = memref.dim %rhs, %c1 : !rhs_memref_type
+  %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+  scf.for %m = %c0 to %M step %c1 {
+    scf.for %n = %c0 to %N step %c1 {
+      %f16 = memref.load %rhs[%m, %n] : !rhs_memref_type
+      %f32 = arith.extf %f16 : f16 to f32
+      memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+    }
+  }
+  %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+  memref.dealloc %tmp_alloc : memref<?x?xf32>
+  return
+}
+
+func.func @print_res_as_memref_32(%res: !res_memref_type) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %tmp_alloc = memref.alloc(%M, %N) : memref<?x?xf32>
+  scf.for %m = %c0 to %M step %c1 {
+    scf.for %n = %c0 to %N step %c1 {
+      %f16 = memref.load %res[%m, %n] : !res_memref_type
+      %f32 = arith.extf %f16 : f16 to f32
+      memref.store %f32, %tmp_alloc[%m, %n] : memref<?x?xf32>
+    }
+  }
+  %casted = memref.cast %tmp_alloc : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%casted) : (memref<*xf32>) -> ()
+  memref.dealloc %tmp_alloc : memref<?x?xf32>
+  return
+}
+
+func.func @main() {
+  %lhs = memref.alloc() : !lhs_memref_type
+  %rhs = memref.alloc() : !rhs_memref_type
+  %res = memref.alloc() : !res_memref_type
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c32 = arith.constant 32 : index
+
+  // Initialize the lhs matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %K step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f16
+      memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+    }
+  }
+  // Initialize the rhs matrix with a linspace function.
+  scf.for %r = %c0 to %K step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+      memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+    }
+  }
+  // Initialize the res matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+      memref.store %idx, %res[%r, %c] : !res_memref_type
+    }
+  }
+
+  %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf16>
+  %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf16>
+  %ures = memref.cast %res : !res_memref_type to memref<*xf16>
+  gpu.host_register %ulhs : memref<*xf16>
+  gpu.host_register %urhs : memref<*xf16>
+  gpu.host_register %ures : memref<*xf16>
+
+  // Print the memrefs before computation.
+  call @print_lhs_as_memref_32(%lhs) : (!lhs_memref_type) -> ()
+  // CHECK: [0,   0.015625,   0.03125,   0.046875,   0.0625,   0.078125,   0.09375,   0.109375,   0.125,   0.140625,   0.15625,   0.171875,   0.1875,   0.203125,   0.21875,   0.234375], 
+  // CHECK: [0.25,   0.265625,   0.28125,   0.296875,   0.3125,   0.328125,   0.34375,   0.359375,   0.375,   0.390625,   0.40625,   0.421875,   0.4375,   0.453125,   0.46875,   0.484375], 
+  // CHECK: [0.5,   0.515625,   0.53125,   0.546875,   0.5625,   0.578125,   0.59375,   0.609375,   0.625,   0.640625,   0.65625,   0.671875,   0.6875,   0.703125,   0.71875,   0.734375], 
+  // CHECK: [0.75,   0.765625,   0.78125,   0.796875,   0.8125,   0.828125,   0.84375,   0.859375,   0.875,   0.890625,   0.90625,   0.921875,   0.9375,   0.953125,   0.96875,   0.984375], 
+  // CHECK: [1,   1.01562,   1.03125,   1.04688,   1.0625,   1.07812,   1.09375,   1.10938,   1.125,   1.14062,   1.15625,   1.17188,   1.1875,   1.20312,   1.21875,   1.23438], 
+  // CHECK: [1.25,   1.26562,   1.28125,   1.29688,   1.3125,   1.32812,   1.34375,   1.35938,   1.375,   1.39062,   1.40625,   1.42188,   1.4375,   1.45312,   1.46875,   1.48438], 
+  // CHECK: [1.5,   1.51562,   1.53125,   1.54688,   1.5625,   1.57812,   1.59375,   1.60938,   1.625,   1.64062,   1.65625,   1.67188,   1.6875,   1.70312,   1.71875,   1.73438], 
+  // CHECK: [1.75,   1.76562,   1.78125,   1.79688,   1.8125,   1.82812,   1.84375,   1.85938,   1.875,   1.89062,   1.90625,   1.92188,   1.9375,   1.95312,   1.96875,   1.98438], 
+  // CHECK: [2,   2.01562,   2.03125,   2.04688,   2.0625,   2.07812,   2.09375,   2.10938,   2.125,   2.14062,   2.15625,   2.17188,   2.1875,   2.20312,   2.21875,   2.23438], 
+  // CHECK: [2.25,   2.26562,   2.28125,   2.29688,   2.3125,   2.32812,   2.34375,   2.35938,   2.375,   2.39062,   2.40625,   2.42188,   2.4375,   2.45312,   2.46875,   2.48438], 
+  // CHECK: [2.5,   2.51562,   2.53125,   2.54688,   2.5625,   2.57812,   2.59375,   2.60938,   2.625,   2.64062,   2.65625,   2.67188,   2.6875,   2.70312,   2.71875,   2.73438], 
+  // CHECK: [2.75,   2.76562,   2.78125,   2.79688,   2.8125,   2.82812,   2.84375,   2.85938,   2.875,   2.89062,   2.90625,   2.92188,   2.9375,   2.95312,   2.96875,   2.98438], 
+  // CHECK: [3,   3.01562,   3.03125,   3.04688,   3.0625,   3.07812,   3.09375,   3.10938,   3.125,   3.14062,   3.15625,   3.17188,   3.1875,   3.20312,   3.21875,   3.23438], 
+  // CHECK: [3.25,   3.26562,   3.28125,   3.29688,   3.3125,   3.32812,   3.34375,   3.35938,   3.375,   3.39062,   3.40625,   3.42188,   3.4375,   3.45312,   3.46875,   3.48438], 
+  // CHECK: [3.5,   3.51562,   3.53125,   3.54688,   3.5625,   3.57812,   3.59375,   3.60938,   3.625,   3.64062,   3.65625,   3.67188,   3.6875,   3.70312,   3.71875,   3.73438], 
+  // CHECK: [3.75,   3.76562,   3.78125,   3.79688,   3.8125,   3.82812,   3.84375,   3.85938,   3.875,   3.89062,   3.90625,   3.92188,   3.9375,   3.95312,   3.96875,   3.98438]
+
+  call @print_rhs_as_memref_32(%rhs) : (!rhs_memref_type) -> ()
+  // CHECK: [0,   0.015625,   0.03125,   0.046875,   0.0625,   0.078125,   0.09375,   0.109375], 
+  // CHECK: [0.125,   0.140625,   0.15625,   0.171875,   0.1875,   0.203125,   0.21875,   0.234375], 
+  // CHECK: [0.25,   0.265625,   0.28125,   0.296875,   0.3125,   0.328125,   0.34375,   0.359375], 
+  // CHECK: [0.375,   0.390625,   0.40625,   0.421875,   0.4375,   0.453125,   0.46875,   0.484375], 
+  // CHECK: [0.5,   0.515625,   0.53125,   0.546875,   0.5625,   0.578125,   0.59375,   0.609375], 
+  // CHECK: [0.625,   0.640625,   0.65625,   0.671875,   0.6875,   0.703125,   0.71875,   0.734375], 
+  // CHECK: [0.75,   0.765625,   0.78125,   0.796875,   0.8125,   0.828125,   0.84375,   0.859375], 
+  // CHECK: [0.875,   0.890625,   0.90625,   0.921875,   0.9375,   0.953125,   0.96875,   0.984375], 
+  // CHECK: [1,   1.01562,   1.03125,   1.04688,   1.0625,   1.07812,   1.09375,   1.10938], 
+  // CHECK: [1.125,   1.14062,   1.15625,   1.17188,   1.1875,   1.20312,   1.21875,   1.23438], 
+  // CHECK: [1.25,   1.26562,   1.28125,   1.29688,   1.3125,   1.32812,   1.34375,   1.35938], 
+  // CHECK: [1.375,   1.39062,   1.40625,   1.42188,   1.4375,   1.45312,   1.46875,   1.48438], 
+  // CHECK: [1.5,   1.51562,   1.53125,   1.54688,   1.5625,   1.57812,   1.59375,   1.60938], 
+  // CHECK: [1.625,   1.64062,   1.65625,   1.67188,   1.6875,   1.70312,   1.71875,   1.73438], 
+  // CHECK: [1.75,   1.76562,   1.78125,   1.79688,   1.8125,   1.82812,   1.84375,   1.85938], 
+  // CHECK: [1.875,   1.89062,   1.90625,   1.92188,   1.9375,   1.95312,   1.96875,   1.98438]
+
+  call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+  // CHECK: [0,   0.015625,   0.03125,   0.046875,   0.0625,   0.078125,   0.09375,   0.109375], 
+  // CHECK: [0.125,   0.140625,   0.15625,   0.171875,   0.1875,   0.203125,   0.21875,   0.234375], 
+  // CHECK: [0.25,   0.265625,   0.28125,   0.296875,   0.3125,   0.328125,   0.34375,   0.359375], 
+  // CHECK: [0.375,   0.390625,   0.40625,   0.421875,   0.4375,   0.453125,   0.46875,   0.484375], 
+  // CHECK: [0.5,   0.515625,   0.53125,   0.546875,   0.5625,   0.578125,   0.59375,   0.609375], 
+  // CHECK: [0.625,   0.640625,   0.65625,   0.671875,   0.6875,   0.703125,   0.71875,   0.734375], 
+  // CHECK: [0.75,   0.765625,   0.78125,   0.796875,   0.8125,   0.828125,   0.84375,   0.859375], 
+  // CHECK: [0.875,   0.890625,   0.90625,   0.921875,   0.9375,   0.953125,   0.96875,   0.984375], 
+  // CHECK: [1,   1.01562,   1.03125,   1.04688,   1.0625,   1.07812,   1.09375,   1.10938], 
+  // CHECK: [1.125,   1.14062,   1.15625,   1.17188,   1.1875,   1.20312,   1.21875,   1.23438], 
+  // CHECK: [1.25,   1.26562,   1.28125,   1.29688,   1.3125,   1.32812,   1.34375,   1.35938], 
+  // CHECK: [1.375,   1.39062,   1.40625,   1.42188,   1.4375,   1.45312,   1.46875,   1.48438], 
+  // CHECK: [1.5,   1.51562,   1.53125,   1.54688,   1.5625,   1.57812,   1.59375,   1.60938], 
+  // CHECK: [1.625,   1.64062,   1.65625,   1.67188,   1.6875,   1.70312,   1.71875,   1.73438], 
+  // CHECK: [1.75,   1.76562,   1.78125,   1.79688,   1.8125,   1.82812,   1.84375,   1.85938], 
+  // CHECK: [1.875,   1.89062,   1.90625,   1.92188,   1.9375,   1.95312,   1.96875,   1.98438]
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+
+    linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+                 outs(%res: !res_memref_type)
+
+    gpu.terminator
+  }
+
+
+  // Print the result memref after computation.
+  // This has been verified against other f16 CUDA implementations.
+  call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+  // CHECK: [2.42188,   2.4668,   2.51172,   2.55664,   2.60156,   2.64648,   2.69141,   2.73633], 
+  // CHECK: [6.29688,   6.40625,   6.51172,   6.61719,   6.72656,   6.83594,   6.94141,   7.04688], 
+  // CHECK: [10.1719,   10.3438,   10.5156,   10.6797,   10.8516,   11.0234,   11.1875,   11.3594], 
+  // CHECK: [14.0469,   14.2812,   14.5156,   14.7422,   14.9766,   15.2109,   15.4375,   15.6719], 
+  // CHECK: [17.9219,   18.2188,   18.5156,   18.8125,   19.0938,   19.3906,   19.6875,   19.9844], 
+  // CHECK: [21.7969,   22.1562,   22.5156,   22.875,   23.2188,   23.5781,   23.9375,   24.2969], 
+  // CHECK: [25.6719,   26.0938,   26.5156,   26.9375,   27.3438,   27.7656,   28.1875,   28.6094], 
+  // CHECK: [29.5469,   30.0312,   30.5156,   31,   31.4688,   31.9531,   32.4375,   32.9375], 
+  // CHECK: [33.4375,   33.9688,   34.5,   35.0625,   35.5938,   36.1562,   36.6875,   37.25], 
+  // CHECK: [37.3125,   37.9062,   38.5,   39.125,   39.7188,   40.3438,   40.9375,   41.5625], 
+  // CHECK: [41.1875,   41.8438,   42.5,   43.1875,   43.8438,   44.5312,   45.1875,   45.875], 
+  // CHECK: [45.0625,   45.7812,   46.5,   47.25,   47.9688,   48.7188,   49.4375,   50.1875], 
+  // CHECK: [48.9375,   49.7188,   50.5,   51.3125,   52.0938,   52.9062,   53.6875,   54.5], 
+  // CHECK: [52.8125,   53.6562,   54.5,   55.375,   56.2188,   57.0938,   57.9375,   58.8125], 
+  // CHECK: [56.6875,   57.5938,   58.5,   59.4375,   60.3438,   61.2812,   62.1875,   63.125], 
+  // CHECK: [60.5625,   61.5312,   62.5,   63.5,   64.5,   65.4375,   66.4375,   67.4375]
+
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul 
+    : (!transform.any_op) -> ()
+}

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
new file mode 100644
index 0000000000000..f71e04a500698
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -0,0 +1,178 @@
+// RUN: mlir-opt %s \
+// RUN:   -test-transform-dialect-interpreter \
+// RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC
+
+// CHECK-MMA-SYNC-LABEL: func @main() {
+//       CHECK-MMA-SYNC:   nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled} 
+//  CHECK-MMA-SYNC-SAME:     : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+
+// Tested to run locally in 1.7s.
+
+// RUN: mlir-opt %s \
+// RUN:   -test-transform-dialect-interpreter \
+// RUN:   -test-transform-dialect-erase-schedule \
+// RUN:   -gpu-kernel-outlining \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -expand-strided-metadata \
+// RUN:   -lower-affine \
+// RUN:   -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN:   -gpu-to-llvm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+!lhs_memref_type = memref<16x4xf32>
+!rhs_memref_type = memref<4x8xf32>
+!res_memref_type = memref<16x8xf32>
+
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f32 {
+  %r = arith.index_cast %ridx : index to i32
+  %c = arith.index_cast %cidx : index to i32
+  %strideC = arith.index_cast %strideCidx : index to i32
+  %2 = arith.muli %r, %strideC : i32
+  %3 = arith.addi %c, %2 : i32
+  %4 = arith.sitofp %3 : i32 to f32
+  return %4: f32
+}
+
+func.func @main() {
+  %lhs = memref.alloc() : !lhs_memref_type
+  %rhs = memref.alloc() : !rhs_memref_type
+  %res = memref.alloc() : !res_memref_type
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+  %f1 = arith.constant 1.0e+00 : f32
+  %f0 = arith.constant 0.0e+00 : f32
+  %c32 = arith.constant 32 : index
+
+  // Initialize the lhs matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %K step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f32
+      memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+    }
+  }
+  // Initialize the rhs matrix with a linspace function.
+  scf.for %r = %c0 to %K step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+    }
+  }
+  // Initialize the res matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %res[%r, %c] : !res_memref_type
+    }
+  }
+
+  %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf32>
+  %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf32>
+  %ures = memref.cast %res : !res_memref_type to memref<*xf32>
+  gpu.host_register %ulhs : memref<*xf32>
+  gpu.host_register %urhs : memref<*xf32>
+  gpu.host_register %ures : memref<*xf32>
+
+  // Print the memrefs before computation.
+  call @printMemrefF32(%ulhs) : (memref<*xf32>) -> ()
+  // CHECK: [0,  1,  2,  3],
+  // CHECK: [4,  5,  6,  7],
+  // CHECK: [8,  9, 10, 11],
+  // CHECK: [12, 13, 14, 15],
+  // CHECK: [16, 17, 18, 19],
+  // CHECK: [20, 21, 22, 23],
+  // CHECK: [24, 25, 26, 27],
+  // CHECK: [28, 29, 30, 31],
+  // CHECK: [32, 33, 34, 35],
+  // CHECK: [36, 37, 38, 39],
+  // CHECK: [40, 41, 42, 43],
+  // CHECK: [44, 45, 46, 47],
+  // CHECK: [48, 49, 50, 51],
+  // CHECK: [52, 53, 54, 55],
+  // CHECK: [56, 57, 58, 59],
+  // CHECK: [60, 61, 62, 63]
+
+  call @printMemrefF32(%urhs) : (memref<*xf32>) -> ()
+  // CHECK: [0,  1,  2,  3,  4,  5,  6,  7],
+  // CHECK: [8,  9, 10, 11, 12, 13, 14, 15],
+  // CHECK: [16, 17, 18, 19, 20, 21, 22, 23],
+  // CHECK: [24, 25, 26, 27, 28, 29, 30, 31]
+
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+  // CHECK: [0,   1,   2,   3,   4,   5,   6,   7],
+  // CHECK: [8,   9,  10,  11,  12,  13,  14,  15],
+  // CHECK: [16,  17,  18,  19,  20,  21,  22,  23],
+  // CHECK: [24,  25,  26,  27,  28,  29,  30,  31],
+  // CHECK: [32,  33,  34,  35,  36,  37,  38,  39],
+  // CHECK: [40,  41,  42,  43,  44,  45,  46,  47],
+  // CHECK: [48,  49,  50,  51,  52,  53,  54,  55],
+  // CHECK: [56,  57,  58,  59,  60,  61,  62,  63],
+  // CHECK: [64,  65,  66,  67,  68,  69,  70,  71],
+  // CHECK: [72,  73,  74,  75,  76,  77,  78,  79],
+  // CHECK: [80,  81,  82,  83,  84,  85,  86,  87],
+  // CHECK: [88,  89,  90,  91,  92,  93,  94,  95],
+  // CHECK: [96,  97,  98,  99, 100, 101, 102, 103],
+  // CHECK: [104, 105, 106, 107, 108, 109, 110, 111],
+  // CHECK: [112, 113, 114, 115, 116, 117, 118, 119],
+  // CHECK: [120, 121, 122, 123, 124, 125, 126, 127]
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+
+    linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+                 outs(%res: !res_memref_type)
+
+    gpu.terminator
+  }
+
+
+  // Print the result memref after computation.
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+
+  // CHECK: [112, 119, 126, 133, 140, 147, 154, 161],
+  // CHECK: [312, 335, 358, 381, 404, 427, 450, 473],
+  // CHECK: [512, 551, 590, 629, 668, 707, 746, 785],
+  // CHECK: [712, 767, 822, 877, 932, 987, 1042, 1097],
+  // CHECK: [912, 983, 1054, 1125, 1196, 1267, 1338, 1409],
+  // CHECK: [1112, 1199, 1286, 1373, 1460, 1547, 1634, 1721],
+  // CHECK: [1312, 1415, 1518, 1621, 1724, 1827, 1930, 2033],
+  // CHECK: [1512, 1631, 1750, 1869, 1988, 2107, 2226, 2345],
+  // CHECK: [1712, 1847, 1982, 2117, 2252, 2387, 2522, 2657],
+  // CHECK: [1912, 2063, 2214, 2365, 2516, 2667, 2818, 2969],
+  // CHECK: [2112, 2279, 2446, 2613, 2780, 2947, 3114, 3281],
+  // CHECK: [2312, 2495, 2678, 2861, 3044, 3227, 3410, 3593],
+  // CHECK: [2512, 2711, 2910, 3109, 3308, 3507, 3706, 3905],
+  // CHECK: [2712, 2927, 3142, 3357, 3572, 3787, 4002, 4217],
+  // CHECK: [2912, 3143, 3374, 3605, 3836, 4067, 4298, 4529],
+  // CHECK: [3112, 3359, 3606, 3853, 4100, 4347, 4594, 4841]
+
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %matmul 
+    : (!transform.any_op) -> ()
+}

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 479e04e7df686..72c565b729d55 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2758,6 +2758,64 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "NVGPUTransformOps",
+    srcs = glob([
+        "lib/Dialect/NVGPU/TransformOps/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/NVGPU/TransformOps/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArithUtils",
+        ":AffineDialect",
+        ":DialectUtils",
+        ":GPUDialect",
+        ":IR",
+        ":LinalgDialect",
+        ":MemRefDialect",
+        ":NVGPUDialect",
+        ":NVGPUTransformOpsIncGen",
+        ":Support",
+        ":TransformDialect",
+        ":VectorDialect",
+        "//llvm:Support",
+    ],
+)
+
+td_library(
+    name = "NVGPUTransformOpsTdFiles",
+    srcs = glob([
+        "include/mlir/Dialect/NVGPU/TransformOps/*.td",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "NVGPUTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td",
+    deps = [
+        ":NVGPUTransformOpsTdFiles",
+    ],
+)
+
 cc_library(
     name = "NVGPUUtils",
     srcs = ["lib/Dialect/NVGPU/Utils/MMAUtils.cpp"],
@@ -7685,6 +7743,7 @@ cc_library(
         ":NVGPUPassIncGen",
         ":NVGPUToNVVM",
         ":NVGPUTransforms",
+        ":NVGPUTransformOps",
         ":NVVMDialect",
         ":OpenACCDialect",
         ":OpenMPDialect",


        

