[Mlir-commits] [mlir] 13f4e88 - Revert "Revert "[mlir][Transform] Add support for mma.sync m16n8k16 f16 rewrite." and "[mlir][Transform] Introduce nvgpu transform extensions""
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jun 27 23:50:13 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-28T06:50:05Z
New Revision: 13f4e889c55288180dfef494c24a123356517d92
URL: https://github.com/llvm/llvm-project/commit/13f4e889c55288180dfef494c24a123356517d92
DIFF: https://github.com/llvm/llvm-project/commit/13f4e889c55288180dfef494c24a123356517d92.diff
LOG: Revert "Revert "[mlir][Transform] Add support for mma.sync m16n8k16 f16 rewrite." and "[mlir][Transform] Introduce nvgpu transform extensions""
This reverts commit 6506692fe619ef8a1f7c6ea829d9a9eceb31622d.
Differential Revision: https://reviews.llvm.org/D153845
Added:
mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
Modified:
mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/NVGPU/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
index d48de9df4b71b..49a07c997a30c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(IR)
+add_subdirectory(TransformOps)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..d75ae3dd5d017
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+# Generate the C++ op declarations (.h.inc) and definitions (.cpp.inc) for the
+# NVGPU transform dialect extension ops described in NVGPUTransformOps.td.
+set(LLVM_TARGET_DEFINITIONS NVGPUTransformOps.td)
+mlir_tablegen(NVGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(NVGPUTransformOps.cpp.inc -gen-op-defs)
+# Libraries that include the generated files list this target under DEPENDS
+# (e.g. MLIRNVGPUTransformOps) so the .inc files are generated first.
+add_public_tablegen_target(MLIRNVGPUTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
new file mode 100644
index 0000000000000..0c7b9d865aa24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
@@ -0,0 +1,43 @@
+//===- NVGPUTransformOps.h - NVGPU transform ops ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+// Forward declarations for types referenced by the generated op declarations
+// included below (e.g. `linalg::LinalgOp` appears in `applyToOne`).
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace nvgpu {
+/// Register the NVGPU transform dialect extension, which provides the
+/// `transform.nvgpu.*` ops, with `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace nvgpu
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// NVGPU Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
new file mode 100644
index 0000000000000..168a445b62ccf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -0,0 +1,51 @@
+//===- NVGPUTransformOps.td - NVGPU transform ops ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVGPU_TRANSFORM_OPS
+#define NVGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+def RewriteMatmulAsMmaSyncOp :
+  Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformEachOpTrait,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Rewrite a matmul operation on memref to an mma.sync operation on vectors.
+
+    Memory copies with the required access patterns are automatically inserted:
+    each lane loads its fragment of the operands with `memref.load`, packs the
+    scalars into vectors, performs `nvgpu.mma.sync`, and writes its fragment of
+    the result back with `memref.store`. The lane id is currently taken to be
+    `gpu.thread_id x`, i.e. the payload is assumed to run on a single warp.
+
+    Supported cases (operand shapes and element types must match exactly):
+      * `linalg.matmul` 16x4 * 4x8 -> 16x8 on f32, lowered to m16n8k4 with
+        tf32 math enabled;
+      * `linalg.matmul` 16x16 * 16x8 -> 16x8 on f16 with f16 accumulation,
+        lowered to m16n8k16.
+
+    #### Return modes
+
+    This op consumes the `target` handle and produces no results. Each target
+    that is rewritten is erased. Operations that do not have a 1-1 mapping to
+    mma.sync operations are not rewritten and produce a silenceable failure.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp linalgOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+#endif // NVGPU_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index db15dff136cd1..106d5af5cfb7d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
@@ -137,6 +138,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
transform::registerPDLExtension(registry);
diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
index 7117520599fa6..63b4d8b99f53f 100644
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..973e5268389fd
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_dialect_library(MLIRNVGPUTransformOps
+ NVGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU/TransformOps
+
+ DEPENDS
+ MLIRNVGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffineDialect
+ MLIRArithDialect
+ MLIRIR
+ MLIRLinalgDialect
+ MLIRNVGPUDialect
+ MLIRParser
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRVectorTransforms
+ )
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
new file mode 100644
index 0000000000000..b08b105d91e19
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -0,0 +1,488 @@
+//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::nvgpu;
+using namespace mlir::transform;
+
+#define DEBUG_TYPE "nvgpu-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+/// Row and column coordinates, as affine expressions, of one fragment element.
+/// In this file the expressions are written in terms of affine dimension 0,
+/// which is later bound to the lane id when the coordinates are materialized.
+struct RowColIndexing {
+  RowColIndexing(AffineExpr row, AffineExpr col)
+      : rowExpr(row), colExpr(col) {}
+
+  /// Row coordinate of the element.
+  AffineExpr row() const { return rowExpr; }
+
+  /// Column coordinate of the element.
+  AffineExpr col() const { return colExpr; }
+
+  /// Debug printer: emits "- indexing: <row>, <col>".
+  void print(llvm::raw_ostream &os) const {
+    os << "- indexing: " << rowExpr << ", " << colExpr;
+  }
+
+private:
+  AffineExpr rowExpr;
+  AffineExpr colExpr;
+};
+
+/// Helper struct to provide a simple mapping from matmul operations to the
+/// corresponding mma.sync operation. This is constrained to the case where the
+/// matmul matches the mma.sync operation 1-1.
+///
+/// The builder `b`, location `loc` and lane id `laneId` given at construction
+/// are what `buildMmaSync` passes down to the private load/store helpers,
+/// which take them as explicit parameters. `b` is held by reference and must
+/// outlive this helper.
+struct MmaSyncBuilder {
+ MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
+ : b(b), loc(loc), laneId(laneId) {}
+
+ /// Callback producing, for one lane (affine dimension 0 stands for the lane
+ /// id), the (row, col) position of every fragment element. Element i of the
+ /// returned list corresponds to row-major linear position i of the matching
+ /// vector operand.
+ using IndexCalculator =
+ std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
+
+ /// Create the mma.sync operation corresponding to `linalgOp` along with all
+ /// the supporting load/store and vector operations.
+ /// Operands are expected to be 2-D memrefs. Returns failure, without creating
+ /// any IR, when no supported mma.sync variant matches the operand shapes and
+ /// element types.
+ FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
+
+private:
+ /// Everything needed to emit one supported mma.sync variant.
+ struct MmaSyncInfo {
+ /// Index calculators for the (lhs, rhs, result) operands.
+ std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
+ /// Shapes of the (lhs, rhs, result) vector operands of nvgpu.mma.sync.
+ std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
+ vectorShapes;
+ /// The (m, n, k) shape of the mma.sync instruction.
+ SmallVector<int64_t> mmaShape;
+ /// Whether nvgpu.mma.sync is built with tf32 math enabled (f32 inputs).
+ bool tf32Enabled;
+ };
+
+ /// Return the index calculators and vector shapes for the given (m, n, k)
+ /// `opShape` and (lhs, rhs, result) `elementalTypes`, or failure if the
+ /// combination is not supported. This is the toplevel switch that should just
+ /// be Tablegen'd in the future.
+ FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
+ TypeRange elementalTypes);
+
+ //===--------------------------------------------------------------------===//
+ // Instruction-specific row, column indexing expression builders.
+ // These should all be declaratively specified via Tablegen in the future.
+ // The Tablegen specification should be as straightforward as possible to
+ // only model the existing size and type combinations.
+ //
+ // In all builders below, affine dimension 0 is the lane id; `%laneid >> 2`
+ // is expressed as `floorDiv(4)` and `%laneid % 4` as `% 4`.
+ //===--------------------------------------------------------------------===//
+ //
+ // TODO: Tablegen all this.
+ //===--------------------------------------------------------------------===//
+ // m16n8k4 tf32 case.
+ //===--------------------------------------------------------------------===//
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ /// row = groupID for a0
+ /// groupID + 8 for a1
+ /// col = threadIDInGroup
+ static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{groupID, threadIDInGroup},
+ RowColIndexing{groupID + 8, threadIDInGroup}};
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ /// row = threadIDInGroup
+ /// col = groupID
+ static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{threadIDInGroup, groupID}};
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ /// row = groupID for c0 and c1
+ /// groupID + 8 for c2 and c3
+ /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
+ static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1},
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
+ }
+
+ //===--------------------------------------------------------------------===//
+ // m16n8k16 f16 case.
+ //===--------------------------------------------------------------------===//
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
+ /// groupID + 8 Otherwise
+ ///
+ /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
+ /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
+ static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
+ };
+ // clang-format on
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
+ /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
+ ///
+ /// col = groupID
+ static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
+ RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
+ RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
+ RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
+ };
+ // clang-format on
+ }
+
+ /// From the NVIDIA doc:
+ /// groupID = %laneid >> 2
+ /// threadIDInGroup = %laneid % 4
+ ///
+ /// row = groupID for ci where i < 2
+ /// groupID + 8 for ci where i >= 2
+ ///
+ /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
+ static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ // clang-format off
+ return {
+ RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
+ };
+ // clang-format on
+ }
+
+ //===--------------------------------------------------------------------===//
+ /// Helper functions to create customizable load and stores operations. The
+ /// specific shapes of each MMA instruction are passed via the
+ /// IndexCalculator callback.
+ //===--------------------------------------------------------------------===//
+ /// Build a list of memref.load operations indexed at `(row, col)` indices
+ /// that make sense for a particular MMA instruction and specified via the
+ /// IndexCalculator callback. `laneId` is substituted for affine dimension 0
+ /// of each indexing expression; the loads are returned in callback order.
+ SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn);
+
+ /// Perform a distributed load of a vector operand of `vectorShape` for a
+ /// particular MMA instruction whose `(row, col)` indices are specified via
+ /// the IndexCalculator callback. Each `laneId` loads the subportion of the
+ /// data that makes sense for the particular MMA operation.
+ /// The `vectorShape` matches existing NVGPU dialect op specification but
+ /// could also be flattened in the future if needed for simplification.
+ Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn,
+ ArrayRef<int64_t> vectorShape);
+
+ /// Build a list of memref.store operations indexed at `(row, col)` indices
+ /// that make sense for a particular MMA instruction and specified via the
+ /// IndexCalculator callback. `toStore[i]` is stored at the i-th indexing.
+ SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+ ValueRange toStore,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn);
+
+ /// Perform a distributed store of a vector operand of `vectorShape` for a
+ /// particular MMA instruction whose `(row, col)` indices are specified via
+ /// the IndexCalculator callback. Each `laneId` stores the subportion of the
+ /// data that makes sense for the particular MMA operation.
+ /// The `vectorShape` matches existing NVGPU dialect op specification but
+ /// could also be flattened in the future if needed for simplification.
+ /// Note: the implementation currently does not read `vectorShape`; the
+ /// element order is derived from the type of `vectorToStore`.
+ SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+ OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+ Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
+
+ /// Builder used to create all IR (held by reference).
+ OpBuilder &b;
+ /// Location attached to all created operations.
+ Location loc;
+ /// Lane id bound to affine dimension 0 of the indexing expressions.
+ OpFoldResult laneId;
+};
+
+//===--------------------------------------------------------------------===//
+/// Helper functions to create customizable load and stores operations. The
+/// specific shapes of each MMA instruction are passed via the
+/// IndexCalculator callback.
+//===--------------------------------------------------------------------===//
+
+/// Visit every element of the statically shaped vector `vector` in row-major
+/// order. For each element, `applyFn(vector, linearIdx, indices)` is invoked
+/// and its result is forwarded to `reduceFn(result, linearIdx, indices)`,
+/// where `linearIdx` is the row-major linear position of the element and
+/// `indices` its multi-dimensional position.
+///
+/// The element count is taken from the vector type itself rather than from
+/// `shape[0] * strides[0]`, so rank-0 vectors (whose shape is empty) are
+/// visited as a single element instead of indexing past the end of the shape.
+template <typename ApplyFn, typename ReduceFn>
+static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
+                                           ReduceFn reduceFn) {
+  auto vectorType = cast<VectorType>(vector.getType());
+  auto strides = computeStrides(vectorType.getShape());
+  for (int64_t idx = 0, e = vectorType.getNumElements(); idx < e; ++idx) {
+    auto indices = delinearize(idx, strides);
+    reduceFn(applyFn(vector, idx, indices), idx, indices);
+  }
+}
+
+SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+                                                    OpFoldResult laneId,
+                                                    Value memref,
+                                                    IndexCalculator indexFn) {
+  // Materialize one affine expression of the lane id as an index Value,
+  // folding to a constant when the expression allows it.
+  auto materialize = [&](AffineExpr expr) -> Value {
+    OpFoldResult folded =
+        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
+    return getValueOrCreateConstantIndexOp(b, loc, folded);
+  };
+
+  // One scalar load per (row, col) position owned by this lane, in the order
+  // produced by the index calculator.
+  SmallVector<Value> loadedValues;
+  for (const RowColIndexing &position : indexFn(b.getContext())) {
+    Value rowIdx = materialize(position.row());
+    Value colIdx = materialize(position.col());
+    Value loaded =
+        b.create<memref::LoadOp>(loc, memref, ValueRange{rowIdx, colIdx});
+    loadedValues.push_back(loaded);
+  }
+  return loadedValues;
+}
+
+Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
+    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  SmallVector<Value> scalars =
+      buildMemrefLoads(b, loc, laneId, memref, indexFn);
+
+  // Seed the operand with a splat of the first scalar, then overwrite every
+  // position (the first one included) with its own loaded scalar, visiting
+  // positions in row-major order.
+  auto operandType =
+      VectorType::get(vectorShape, getElementTypeOrSelf(memref.getType()));
+  Value packed = b.create<vector::SplatOp>(loc, operandType, scalars[0]);
+
+  auto pickScalar = [&](Value, int64_t linearIdx, ArrayRef<int64_t>) {
+    return scalars[linearIdx];
+  };
+  auto insertScalar = [&](Value scalar, int64_t, ArrayRef<int64_t> position) {
+    packed = b.create<vector::InsertOp>(loc, scalar, packed, position);
+  };
+  foreachIndividualVectorElement(packed, pickScalar, insertScalar);
+
+  return packed;
+}
+
+SmallVector<Operation *>
+MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+                                  ValueRange toStore, OpFoldResult laneId,
+                                  Value memref, IndexCalculator indexFn) {
+  // Turn an affine expression of the lane id into an index Value, folding to
+  // a constant when possible.
+  auto toIndex = [&](AffineExpr expr) -> Value {
+    return getValueOrCreateConstantIndexOp(
+        b, loc, affine::makeComposedFoldedAffineApply(b, loc, expr, laneId));
+  };
+
+  SmallVector<RowColIndexing> positions = indexFn(b.getContext());
+  SmallVector<Operation *> stores;
+  // zip_equal requires exactly one position per value to store.
+  for (auto [position, scalar] : llvm::zip_equal(positions, toStore)) {
+    Value rowIdx = toIndex(position.row());
+    Value colIdx = toIndex(position.col());
+    Operation *storeOp = b.create<memref::StoreOp>(
+        loc, scalar, memref, ValueRange{rowIdx, colIdx});
+    stores.push_back(storeOp);
+  }
+  return stores;
+}
+
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  // Scalarize `vectorToStore` in row-major order, then scatter each scalar to
+  // the (row, col) position this lane owns. `vectorShape` is not consulted:
+  // the element order comes from the type of `vectorToStore`.
+  SmallVector<Value> scalars;
+  scalars.reserve(32);
+
+  auto extractScalar = [&](Value, int64_t, ArrayRef<int64_t> position) {
+    return b.create<vector::ExtractOp>(loc, vectorToStore, position);
+  };
+  auto collectScalar = [&](Value scalar, int64_t, ArrayRef<int64_t>) {
+    scalars.push_back(scalar);
+  };
+  foreachIndividualVectorElement(vectorToStore, extractScalar, collectScalar);
+
+  return buildMemrefStores(b, loc, scalars, laneId, memref, indexFn);
+}
+
+/// Copy three shape lists into owned vectors, packaged as an
+/// (lhs, rhs, result) tuple in the layout used by MmaSyncInfo::vectorShapes.
+static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+                  SmallVector<int64_t>>
+makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+                 ArrayRef<int64_t> res) {
+  auto toOwned = [](ArrayRef<int64_t> shape) {
+    return SmallVector<int64_t>(shape.begin(), shape.end());
+  };
+  return std::make_tuple(toOwned(lhs), toOwned(rhs), toOwned(res));
+}
+
+FailureOr<MmaSyncBuilder::MmaSyncInfo>
+MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
+                                    TypeRange elementalTypes) {
+  // TODO: Tablegen all this.
+  Type f16Ty = b.getF16Type();
+  Type f32Ty = b.getF32Type();
+  SmallVector<int64_t> mmaShape(opShape.begin(), opShape.end());
+
+  // m16n8k4 on f32 operands and accumulator, executed with tf32 math.
+  bool isM16N8K4F32 = opShape == ArrayRef<int64_t>{16, 8, 4} &&
+                      elementalTypes == TypeRange{f32Ty, f32Ty, f32Ty};
+  if (isM16N8K4F32) {
+    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
+                                       &MmaSyncBuilder::m16n8k4tf32Res),
+                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}), mmaShape,
+                       /*tf32Enabled=*/true};
+  }
+
+  // m16n8k16 on f16 operands with f16 accumulation.
+  // TODO: version with f32 accumulation.
+  bool isM16N8K16F16 = opShape == ArrayRef<int64_t>{16, 8, 16} &&
+                       elementalTypes == TypeRange{f16Ty, f16Ty, f16Ty};
+  if (isM16N8K16F16) {
+    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
+                                       &MmaSyncBuilder::m16n8k16f16Rhs,
+                                       &MmaSyncBuilder::m16n8k16f16Res),
+                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}), mmaShape,
+                       /*tf32Enabled=*/false};
+  }
+
+  // Any other shape / element type combination is unsupported.
+  return failure();
+}
+
+/// Build loads, nvgpu.mma.sync and stores implementing `linalgOp`.
+///
+/// Returns failure, without creating any IR in this function, when the op is
+/// not a 2-input/1-init op on 2-D memrefs, or when its (m, n, k) sizes and
+/// element types match no supported mma.sync variant. Failing gracefully
+/// (instead of asserting) lets the caller report e.g. a matmul on tensors as a
+/// silenceable failure rather than crashing.
+FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+
+  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
+
+  auto lhsMemrefType = dyn_cast<MemRefType>(lhsMemref.getType());
+  auto rhsMemrefType = dyn_cast<MemRefType>(rhsMemref.getType());
+  auto resMemrefType = dyn_cast<MemRefType>(resMemref.getType());
+  if (!lhsMemrefType || !rhsMemrefType || !resMemrefType)
+    return failure();
+  if (lhsMemrefType.getRank() != 2 || rhsMemrefType.getRank() != 2 ||
+      resMemrefType.getRank() != 2)
+    return failure();
+
+  // (M x K) * (K x N) -> (M x N). Dynamic sizes never equal a supported
+  // mma.sync size, so getIndexCalculators rejects them.
+  int64_t m = lhsMemrefType.getShape()[0];
+  int64_t n = rhsMemrefType.getShape()[1];
+  int64_t k = lhsMemrefType.getShape()[1];
+  Type lhsType = lhsMemrefType.getElementType();
+  Type rhsType = rhsMemrefType.getElementType();
+  Type resType = resMemrefType.getElementType();
+
+  FailureOr<MmaSyncInfo> maybeInfo =
+      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
+  if (failed(maybeInfo))
+    return failure();
+
+  MmaSyncInfo info = *maybeInfo;
+  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
+  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
+  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+                                            lhsIndexFn, lhsShape);
+  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+                                            rhsIndexFn, rhsShape);
+  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+                                            resIndexFn, resShape);
+  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
+                                   info.tf32Enabled);
+  // Write this lane's fragment of the accumulator back to the result memref.
+  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+                                 resShape);
+  return res.getDefiningOp();
+}
+
+/// Rewrite a single `linalg.matmul` target into per-lane loads, an
+/// nvgpu.mma.sync and per-lane stores, then erase the target. Any other op, or
+/// a matmul whose shapes/types are unsupported, yields a silenceable failure.
+/// On failure the payload is left as it was: the lane id created for the
+/// attempt is erased again if nothing uses it.
+DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
+    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  bool fail = true;
+  // TODO: more robust detection of matmulOp, with transposes etc.
+  if (isa<linalg::MatmulOp>(linalgOp.getOperation())) {
+    Location loc = linalgOp.getLoc();
+    // TODO: more robust computation of laneId, for now assume a single warp.
+    Value laneId = rewriter.create<gpu::ThreadIdOp>(
+        loc, rewriter.getIndexType(), gpu::Dimension::x);
+    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
+      fail = false;
+    else if (laneId.use_empty())
+      rewriter.eraseOp(laneId.getDefiningOp());
+  }
+
+  if (fail) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "unsupported target op: " << linalgOp;
+    diag.attachNote(linalgOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  rewriter.eraseOp(linalgOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the NVGPU transform ops.
+///
+/// Every dialect whose ops these transforms create must be declared as
+/// generated so it is loaded before the transforms run. RewriteMatmulAsMmaSyncOp
+/// creates affine.apply, arith constants (via getValueOrCreateConstantIndexOp),
+/// gpu.thread_id, memref.load / memref.store, vector splat/insert/extract and
+/// nvgpu.mma.sync.
+class NVGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          NVGPUTransformDialectExtension> {
+public:
+  NVGPUTransformDialectExtension() {
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<affine::AffineDialect>();
+    declareGeneratedDialect<gpu::GPUDialect>();
+    declareGeneratedDialect<memref::MemRefDialect>();
+    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+
+/// Register the NVGPU transform dialect extension with `registry`.
+void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<NVGPUTransformDialectExtension>();
+}
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
new file mode 100644
index 0000000000000..241f218c79c57
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK: #[[$div4:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #[[$mod4:.*]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK: #[[$div4p8:.*]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK: #[[$map3:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+// CHECK: #[[$map4:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 1)>
+
+// 16x8x4 f32 matmul on global memrefs, rewritten for a single warp: each lane
+// loads its fragments of A (2x1), B (1x1) and C (2x2) at offsets derived from
+// gpu.thread_id, issues one tf32 nvgpu.mma.sync, then stores its 2x2 result
+// fragment back into C.
+// CHECK-LABEL: func.func @matmul_16x8x4xf32_global
+func.func @matmul_16x8x4xf32_global(
+ %A: memref<16x4xf32>, %B: memref<4x8xf32>, %C: memref<16x8xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<16x4xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<4x8xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<16x8xf32>) {
+
+// CHECK: %[[TIDX:.*]] = gpu.thread_id x
+// CHECK: %[[VAL_4:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_5:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]][%[[VAL_4]], %[[VAL_5]]] : memref<16x4xf32>
+// CHECK: %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
+// CHECK: %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
+// CHECK: %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
+// B (RHS) fragment: a single element per lane.
+// CHECK: %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
+// CHECK: %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
+// CHECK: %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
+// C (accumulator) fragment: four elements per lane, rows tid/4 and tid/4 + 8.
+// CHECK: %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_19:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_2]][%[[VAL_18]], %[[VAL_19]]] : memref<16x8xf32>
+// CHECK: %[[VAL_21:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_22:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_2]][%[[VAL_21]], %[[VAL_22]]] : memref<16x8xf32>
+// CHECK: %[[VAL_24:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_25:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_2]][%[[VAL_24]], %[[VAL_25]]] : memref<16x8xf32>
+// CHECK: %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
+// CHECK: %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
+// CHECK: %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL_29]], %[[VAL_33]] [1, 1] : f32 into vector<2x2xf32>
+// The warp-level matrix multiply-accumulate itself.
+// CHECK: %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+// Scatter the 2x2 result fragment back into C at the same offsets it was loaded from.
+// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_35]][0, 0] : vector<2x2xf32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_35]][0, 1] : vector<2x2xf32>
+// CHECK: %[[VAL_38:.*]] = vector.extract %[[VAL_35]][1, 0] : vector<2x2xf32>
+// CHECK: %[[VAL_39:.*]] = vector.extract %[[VAL_35]][1, 1] : vector<2x2xf32>
+// CHECK: %[[VAL_40:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_41:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_36]], %[[VAL_2]][%[[VAL_40]], %[[VAL_41]]] : memref<16x8xf32>
+// CHECK: %[[VAL_42:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_43:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_37]], %[[VAL_2]][%[[VAL_42]], %[[VAL_43]]] : memref<16x8xf32>
+// CHECK: %[[VAL_44:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_45:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_38]], %[[VAL_2]][%[[VAL_44]], %[[VAL_45]]] : memref<16x8xf32>
+// CHECK: %[[VAL_46:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_47:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_39]], %[[VAL_2]][%[[VAL_46]], %[[VAL_47]]] : memref<16x8xf32>
+// CHECK: return
+// CHECK: }
+ linalg.matmul ins(%A, %B: memref<16x4xf32>, memref<4x8xf32>)
+ outs(%C: memref<16x8xf32>)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  // Rewrite every linalg.matmul found in the payload as nvgpu.mma.sync.
+  %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %mm
+    : (!transform.any_op) -> ()
+}
+
+// -----
+
+// 16x8x16 f16 matmul with an f16 accumulator. B and C share the memref type
+// memref<16x8xf16>, so their per-lane load patterns below look identical.
+// CHECK-LABEL: func.func @matmul_16x8x16xf16_global
+func.func @matmul_16x8x16xf16_global(
+ %A: memref<16x16xf16>, %B: memref<16x8xf16>, %C: memref<16x8xf16>) {
+ // Per-lane fragments: 8 elements of A, then 4 of B, then 4 of C.
+ // CHECK-COUNT-8: memref.load {{.*}} : memref<16x16xf16>
+ // CHECK-COUNT-8: vector.insert {{.*}} : f16 into vector<4x2xf16>
+ // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+ // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
+ // CHECK-COUNT-4: memref.load {{.*}} : memref<16x8xf16>
+ // CHECK-COUNT-4: vector.insert {{.*}} : f16 into vector<2x2xf16>
+ // A single m16n8k16 mma.sync (no tf32 attribute for f16 operands).
+ // CHECK: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 16]}
+ // CHECK-SAME: : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ // The 2x2 result fragment is extracted and stored back into C.
+ // CHECK-COUNT-4: vector.extract %{{.*}} : vector<2x2xf16>
+ // CHECK-COUNT-4: memref.store %{{.*}} : memref<16x8xf16>
+ linalg.matmul ins(%A, %B: memref<16x16xf16>, memref<16x8xf16>)
+ outs(%C: memref<16x8xf16>)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  // Rewrite every linalg.matmul found in the payload as nvgpu.mma.sync.
+  %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %mm
+    : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
new file mode 100644
index 0000000000000..6788ccea3a222
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/lit.local.cfg
@@ -0,0 +1,2 @@
+# Run the sm_80 tensor-core tests only when the CUDA runner is enabled and
+# sm_80 tests were explicitly requested.
+if not (config.enable_cuda_runner and config.mlir_run_cuda_sm80_tests):
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
new file mode 100644
index 0000000000000..593e4c317e0e4
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir
@@ -0,0 +1,239 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -gpu-kernel-outlining \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-math-to-llvm \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -convert-arith-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -gpu-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// Operand types of the single 16x8x16 f16 matmul under test:
+// LHS 16x16, RHS 16x8, accumulator/result 16x8.
+!lhs_memref_type = memref<16x16xf16>
+!rhs_memref_type = memref<16x8xf16>
+!res_memref_type = memref<16x8xf16>
+
+// Returns (ridx * strideCidx + cidx) / 64 as f16, i.e. the row-major linear
+// index of element (ridx, cidx) scaled by 1/64.
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f16 {
+  %scale = arith.constant 64.0 : f16
+  %row = arith.index_cast %ridx : index to i32
+  %col = arith.index_cast %cidx : index to i32
+  %stride = arith.index_cast %strideCidx : index to i32
+  %row_offset = arith.muli %row, %stride : i32
+  %linear = arith.addi %col, %row_offset : i32
+  %linear_f16 = arith.sitofp %linear : i32 to f16
+  %val = arith.divf %linear_f16, %scale : f16
+  return %val : f16
+}
+
+// Prints the f16 LHS matrix by widening it into a temporary f32 buffer and
+// passing that buffer to @printMemrefF32.
+func.func @print_lhs_as_memref_32(%lhs: !lhs_memref_type) {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = memref.dim %lhs, %zero : !lhs_memref_type
+  %cols = memref.dim %lhs, %one : !lhs_memref_type
+  %buf = memref.alloc(%rows, %cols) : memref<?x?xf32>
+  scf.for %i = %zero to %rows step %one {
+    scf.for %j = %zero to %cols step %one {
+      %narrow = memref.load %lhs[%i, %j] : !lhs_memref_type
+      %wide = arith.extf %narrow : f16 to f32
+      memref.store %wide, %buf[%i, %j] : memref<?x?xf32>
+    }
+  }
+  %unranked = memref.cast %buf : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%unranked) : (memref<*xf32>) -> ()
+  memref.dealloc %buf : memref<?x?xf32>
+  return
+}
+
+// Prints the f16 RHS matrix by widening it into a temporary f32 buffer and
+// passing that buffer to @printMemrefF32.
+func.func @print_rhs_as_memref_32(%rhs: !rhs_memref_type) {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = memref.dim %rhs, %zero : !rhs_memref_type
+  %cols = memref.dim %rhs, %one : !rhs_memref_type
+  %buf = memref.alloc(%rows, %cols) : memref<?x?xf32>
+  scf.for %i = %zero to %rows step %one {
+    scf.for %j = %zero to %cols step %one {
+      %narrow = memref.load %rhs[%i, %j] : !rhs_memref_type
+      %wide = arith.extf %narrow : f16 to f32
+      memref.store %wide, %buf[%i, %j] : memref<?x?xf32>
+    }
+  }
+  %unranked = memref.cast %buf : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%unranked) : (memref<*xf32>) -> ()
+  memref.dealloc %buf : memref<?x?xf32>
+  return
+}
+
+// Prints the f16 result/accumulator matrix by widening it into a temporary
+// f32 buffer and passing that buffer to @printMemrefF32.
+func.func @print_res_as_memref_32(%res: !res_memref_type) {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = memref.dim %res, %zero : !res_memref_type
+  %cols = memref.dim %res, %one : !res_memref_type
+  %buf = memref.alloc(%rows, %cols) : memref<?x?xf32>
+  scf.for %i = %zero to %rows step %one {
+    scf.for %j = %zero to %cols step %one {
+      %narrow = memref.load %res[%i, %j] : !res_memref_type
+      %wide = arith.extf %narrow : f16 to f32
+      memref.store %wide, %buf[%i, %j] : memref<?x?xf32>
+    }
+  }
+  %unranked = memref.cast %buf : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%unranked) : (memref<*xf32>) -> ()
+  memref.dealloc %buf : memref<?x?xf32>
+  return
+}
+
+// Host driver: fills A, B and C with scaled linspace values, prints them,
+// runs the matmul on one block of 32 threads, and prints C afterwards.
+func.func @main() {
+  %lhs = memref.alloc() : !lhs_memref_type
+  %rhs = memref.alloc() : !rhs_memref_type
+  %res = memref.alloc() : !res_memref_type
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+  // The mma.sync rewrite assumes a single warp, so launch exactly 32 threads.
+  %c32 = arith.constant 32 : index
+
+  // Initialize the lhs matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %K step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f16
+      memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+    }
+  }
+  // Initialize the rhs matrix with a linspace function.
+  scf.for %r = %c0 to %K step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+      memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+    }
+  }
+  // Initialize the res (accumulator) matrix with a linspace function; the
+  // matmul accumulates into it.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f16
+      memref.store %idx, %res[%r, %c] : !res_memref_type
+    }
+  }
+
+  // Register the host buffers with the GPU runtime; the kernel below accesses
+  // them in place, without explicit copies.
+  %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf16>
+  %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf16>
+  %ures = memref.cast %res : !res_memref_type to memref<*xf16>
+  gpu.host_register %ulhs : memref<*xf16>
+  gpu.host_register %urhs : memref<*xf16>
+  gpu.host_register %ures : memref<*xf16>
+
+  // Print the memrefs before computation.
+  call @print_lhs_as_memref_32(%lhs) : (!lhs_memref_type) -> ()
+  // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375, 0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+  // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375, 0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+  // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375, 0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+  // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375, 0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+  // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938, 1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+  // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938, 1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+  // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938, 1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+  // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938, 1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438],
+  // CHECK: [2, 2.01562, 2.03125, 2.04688, 2.0625, 2.07812, 2.09375, 2.10938, 2.125, 2.14062, 2.15625, 2.17188, 2.1875, 2.20312, 2.21875, 2.23438],
+  // CHECK: [2.25, 2.26562, 2.28125, 2.29688, 2.3125, 2.32812, 2.34375, 2.35938, 2.375, 2.39062, 2.40625, 2.42188, 2.4375, 2.45312, 2.46875, 2.48438],
+  // CHECK: [2.5, 2.51562, 2.53125, 2.54688, 2.5625, 2.57812, 2.59375, 2.60938, 2.625, 2.64062, 2.65625, 2.67188, 2.6875, 2.70312, 2.71875, 2.73438],
+  // CHECK: [2.75, 2.76562, 2.78125, 2.79688, 2.8125, 2.82812, 2.84375, 2.85938, 2.875, 2.89062, 2.90625, 2.92188, 2.9375, 2.95312, 2.96875, 2.98438],
+  // CHECK: [3, 3.01562, 3.03125, 3.04688, 3.0625, 3.07812, 3.09375, 3.10938, 3.125, 3.14062, 3.15625, 3.17188, 3.1875, 3.20312, 3.21875, 3.23438],
+  // CHECK: [3.25, 3.26562, 3.28125, 3.29688, 3.3125, 3.32812, 3.34375, 3.35938, 3.375, 3.39062, 3.40625, 3.42188, 3.4375, 3.45312, 3.46875, 3.48438],
+  // CHECK: [3.5, 3.51562, 3.53125, 3.54688, 3.5625, 3.57812, 3.59375, 3.60938, 3.625, 3.64062, 3.65625, 3.67188, 3.6875, 3.70312, 3.71875, 3.73438],
+  // CHECK: [3.75, 3.76562, 3.78125, 3.79688, 3.8125, 3.82812, 3.84375, 3.85938, 3.875, 3.89062, 3.90625, 3.92188, 3.9375, 3.95312, 3.96875, 3.98438]
+
+  call @print_rhs_as_memref_32(%rhs) : (!rhs_memref_type) -> ()
+  // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375],
+  // CHECK: [0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+  // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375],
+  // CHECK: [0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+  // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375],
+  // CHECK: [0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+  // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375],
+  // CHECK: [0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+  // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938],
+  // CHECK: [1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+  // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938],
+  // CHECK: [1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+  // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938],
+  // CHECK: [1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+  // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938],
+  // CHECK: [1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438]
+
+  call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+  // CHECK: [0, 0.015625, 0.03125, 0.046875, 0.0625, 0.078125, 0.09375, 0.109375],
+  // CHECK: [0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375],
+  // CHECK: [0.25, 0.265625, 0.28125, 0.296875, 0.3125, 0.328125, 0.34375, 0.359375],
+  // CHECK: [0.375, 0.390625, 0.40625, 0.421875, 0.4375, 0.453125, 0.46875, 0.484375],
+  // CHECK: [0.5, 0.515625, 0.53125, 0.546875, 0.5625, 0.578125, 0.59375, 0.609375],
+  // CHECK: [0.625, 0.640625, 0.65625, 0.671875, 0.6875, 0.703125, 0.71875, 0.734375],
+  // CHECK: [0.75, 0.765625, 0.78125, 0.796875, 0.8125, 0.828125, 0.84375, 0.859375],
+  // CHECK: [0.875, 0.890625, 0.90625, 0.921875, 0.9375, 0.953125, 0.96875, 0.984375],
+  // CHECK: [1, 1.01562, 1.03125, 1.04688, 1.0625, 1.07812, 1.09375, 1.10938],
+  // CHECK: [1.125, 1.14062, 1.15625, 1.17188, 1.1875, 1.20312, 1.21875, 1.23438],
+  // CHECK: [1.25, 1.26562, 1.28125, 1.29688, 1.3125, 1.32812, 1.34375, 1.35938],
+  // CHECK: [1.375, 1.39062, 1.40625, 1.42188, 1.4375, 1.45312, 1.46875, 1.48438],
+  // CHECK: [1.5, 1.51562, 1.53125, 1.54688, 1.5625, 1.57812, 1.59375, 1.60938],
+  // CHECK: [1.625, 1.64062, 1.65625, 1.67188, 1.6875, 1.70312, 1.71875, 1.73438],
+  // CHECK: [1.75, 1.76562, 1.78125, 1.79688, 1.8125, 1.82812, 1.84375, 1.85938],
+  // CHECK: [1.875, 1.89062, 1.90625, 1.92188, 1.9375, 1.95312, 1.96875, 1.98438]
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+                  outs(%res: !res_memref_type)
+    gpu.terminator
+  }
+
+  // Print the result memref after computation.
+  // This has been verified against other f16 CUDA implementations.
+  call @print_res_as_memref_32(%res) : (!res_memref_type) -> ()
+  // CHECK: [2.42188, 2.4668, 2.51172, 2.55664, 2.60156, 2.64648, 2.69141, 2.73633],
+  // CHECK: [6.29688, 6.40625, 6.51172, 6.61719, 6.72656, 6.83594, 6.94141, 7.04688],
+  // CHECK: [10.1719, 10.3438, 10.5156, 10.6797, 10.8516, 11.0234, 11.1875, 11.3594],
+  // CHECK: [14.0469, 14.2812, 14.5156, 14.7422, 14.9766, 15.2109, 15.4375, 15.6719],
+  // CHECK: [17.9219, 18.2188, 18.5156, 18.8125, 19.0938, 19.3906, 19.6875, 19.9844],
+  // CHECK: [21.7969, 22.1562, 22.5156, 22.875, 23.2188, 23.5781, 23.9375, 24.2969],
+  // CHECK: [25.6719, 26.0938, 26.5156, 26.9375, 27.3438, 27.7656, 28.1875, 28.6094],
+  // CHECK: [29.5469, 30.0312, 30.5156, 31, 31.4688, 31.9531, 32.4375, 32.9375],
+  // CHECK: [33.4375, 33.9688, 34.5, 35.0625, 35.5938, 36.1562, 36.6875, 37.25],
+  // CHECK: [37.3125, 37.9062, 38.5, 39.125, 39.7188, 40.3438, 40.9375, 41.5625],
+  // CHECK: [41.1875, 41.8438, 42.5, 43.1875, 43.8438, 44.5312, 45.1875, 45.875],
+  // CHECK: [45.0625, 45.7812, 46.5, 47.25, 47.9688, 48.7188, 49.4375, 50.1875],
+  // CHECK: [48.9375, 49.7188, 50.5, 51.3125, 52.0938, 52.9062, 53.6875, 54.5],
+  // CHECK: [52.8125, 53.6562, 54.5, 55.375, 56.2188, 57.0938, 57.9375, 58.8125],
+  // CHECK: [56.6875, 57.5938, 58.5, 59.4375, 60.3438, 61.2812, 62.1875, 63.125],
+  // CHECK: [60.5625, 61.5312, 62.5, 63.5, 64.5, 65.4375, 66.4375, 67.4375]
+
+  return
+}
+
+// External print helper; presumably resolved from the runner-utils library
+// passed via --shared-libs in the RUN lines.
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  // Rewrite every linalg.matmul found in the payload as nvgpu.mma.sync.
+  %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %mm
+    : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
new file mode 100644
index 0000000000000..f71e04a500698
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir
@@ -0,0 +1,178 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC
+
+// CHECK-MMA-SYNC-LABEL: func @main() {
+// CHECK-MMA-SYNC: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled}
+// CHECK-MMA-SYNC-SAME: : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+
+// Tested to run locally in 1.7s.
+
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -gpu-kernel-outlining \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-math-to-llvm \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -convert-arith-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -gpu-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+// Operand types of the single 16x8x4 f32 matmul under test:
+// LHS 16x4, RHS 4x8, accumulator/result 16x8.
+!lhs_memref_type = memref<16x4xf32>
+!rhs_memref_type = memref<4x8xf32>
+!res_memref_type = memref<16x8xf32>
+
+// Returns ridx * strideCidx + cidx (the row-major linear index of element
+// (ridx, cidx)) converted to f32.
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f32 {
+  %stride = arith.index_cast %strideCidx : index to i32
+  %row = arith.index_cast %ridx : index to i32
+  %col = arith.index_cast %cidx : index to i32
+  %row_offset = arith.muli %row, %stride : i32
+  %linear = arith.addi %col, %row_offset : i32
+  %val = arith.sitofp %linear : i32 to f32
+  return %val : f32
+}
+
+// Host driver: fills A, B and C with linspace values, prints them, runs the
+// matmul on one block of 32 threads, and prints C afterwards.
+func.func @main() {
+  %lhs = memref.alloc() : !lhs_memref_type
+  %rhs = memref.alloc() : !rhs_memref_type
+  %res = memref.alloc() : !res_memref_type
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %M = memref.dim %res, %c0 : !res_memref_type
+  %N = memref.dim %res, %c1 : !res_memref_type
+  %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+  // The mma.sync rewrite assumes a single warp, so launch exactly 32 threads.
+  %c32 = arith.constant 32 : index
+
+  // Initialize the lhs matrix with a linspace function.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %K step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f32
+      memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+    }
+  }
+  // Initialize the rhs matrix with a linspace function.
+  scf.for %r = %c0 to %K step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+    }
+  }
+  // Initialize the res (accumulator) matrix with a linspace function; the
+  // matmul accumulates into it.
+  scf.for %r = %c0 to %M step %c1 {
+    scf.for %c = %c0 to %N step %c1 {
+      %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+      memref.store %idx, %res[%r, %c] : !res_memref_type
+    }
+  }
+
+  // Register the host buffers with the GPU runtime; the kernel below accesses
+  // them in place, without explicit copies.
+  %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf32>
+  %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf32>
+  %ures = memref.cast %res : !res_memref_type to memref<*xf32>
+  gpu.host_register %ulhs : memref<*xf32>
+  gpu.host_register %urhs : memref<*xf32>
+  gpu.host_register %ures : memref<*xf32>
+
+  // Print the memrefs before computation.
+  call @printMemrefF32(%ulhs) : (memref<*xf32>) -> ()
+  // CHECK: [0, 1, 2, 3],
+  // CHECK: [4, 5, 6, 7],
+  // CHECK: [8, 9, 10, 11],
+  // CHECK: [12, 13, 14, 15],
+  // CHECK: [16, 17, 18, 19],
+  // CHECK: [20, 21, 22, 23],
+  // CHECK: [24, 25, 26, 27],
+  // CHECK: [28, 29, 30, 31],
+  // CHECK: [32, 33, 34, 35],
+  // CHECK: [36, 37, 38, 39],
+  // CHECK: [40, 41, 42, 43],
+  // CHECK: [44, 45, 46, 47],
+  // CHECK: [48, 49, 50, 51],
+  // CHECK: [52, 53, 54, 55],
+  // CHECK: [56, 57, 58, 59],
+  // CHECK: [60, 61, 62, 63]
+
+  call @printMemrefF32(%urhs) : (memref<*xf32>) -> ()
+  // CHECK: [0, 1, 2, 3, 4, 5, 6, 7],
+  // CHECK: [8, 9, 10, 11, 12, 13, 14, 15],
+  // CHECK: [16, 17, 18, 19, 20, 21, 22, 23],
+  // CHECK: [24, 25, 26, 27, 28, 29, 30, 31]
+
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+  // CHECK: [0, 1, 2, 3, 4, 5, 6, 7],
+  // CHECK: [8, 9, 10, 11, 12, 13, 14, 15],
+  // CHECK: [16, 17, 18, 19, 20, 21, 22, 23],
+  // CHECK: [24, 25, 26, 27, 28, 29, 30, 31],
+  // CHECK: [32, 33, 34, 35, 36, 37, 38, 39],
+  // CHECK: [40, 41, 42, 43, 44, 45, 46, 47],
+  // CHECK: [48, 49, 50, 51, 52, 53, 54, 55],
+  // CHECK: [56, 57, 58, 59, 60, 61, 62, 63],
+  // CHECK: [64, 65, 66, 67, 68, 69, 70, 71],
+  // CHECK: [72, 73, 74, 75, 76, 77, 78, 79],
+  // CHECK: [80, 81, 82, 83, 84, 85, 86, 87],
+  // CHECK: [88, 89, 90, 91, 92, 93, 94, 95],
+  // CHECK: [96, 97, 98, 99, 100, 101, 102, 103],
+  // CHECK: [104, 105, 106, 107, 108, 109, 110, 111],
+  // CHECK: [112, 113, 114, 115, 116, 117, 118, 119],
+  // CHECK: [120, 121, 122, 123, 124, 125, 126, 127]
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+                  outs(%res: !res_memref_type)
+    gpu.terminator
+  }
+
+  // Print the result memref after computation (initial C plus A * B).
+  call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+  // CHECK: [112, 119, 126, 133, 140, 147, 154, 161],
+  // CHECK: [312, 335, 358, 381, 404, 427, 450, 473],
+  // CHECK: [512, 551, 590, 629, 668, 707, 746, 785],
+  // CHECK: [712, 767, 822, 877, 932, 987, 1042, 1097],
+  // CHECK: [912, 983, 1054, 1125, 1196, 1267, 1338, 1409],
+  // CHECK: [1112, 1199, 1286, 1373, 1460, 1547, 1634, 1721],
+  // CHECK: [1312, 1415, 1518, 1621, 1724, 1827, 1930, 2033],
+  // CHECK: [1512, 1631, 1750, 1869, 1988, 2107, 2226, 2345],
+  // CHECK: [1712, 1847, 1982, 2117, 2252, 2387, 2522, 2657],
+  // CHECK: [1912, 2063, 2214, 2365, 2516, 2667, 2818, 2969],
+  // CHECK: [2112, 2279, 2446, 2613, 2780, 2947, 3114, 3281],
+  // CHECK: [2312, 2495, 2678, 2861, 3044, 3227, 3410, 3593],
+  // CHECK: [2512, 2711, 2910, 3109, 3308, 3507, 3706, 3905],
+  // CHECK: [2712, 2927, 3142, 3357, 3572, 3787, 4002, 4217],
+  // CHECK: [2912, 3143, 3374, 3605, 3836, 4067, 4298, 4529],
+  // CHECK: [3112, 3359, 3606, 3853, 4100, 4347, 4594, 4841]
+
+  return
+}
+
+// External print helper; presumably resolved from the runner-utils library
+// passed via --shared-libs in the RUN lines.
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  // Rewrite every linalg.matmul found in the payload as nvgpu.mma.sync.
+  %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  transform.nvgpu.rewrite_matmul_as_mma_sync %mm
+    : (!transform.any_op) -> ()
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 479e04e7df686..72c565b729d55 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2758,6 +2758,64 @@ cc_library(
],
)
+
+# Transform dialect extension providing the NVGPU transform ops.
+cc_library(
+    name = "NVGPUTransformOps",
+    srcs = glob([
+        "lib/Dialect/NVGPU/TransformOps/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/NVGPU/TransformOps/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":AffineDialect",
+        ":ArithDialect",
+        ":ArithUtils",
+        ":DialectUtils",
+        ":GPUDialect",
+        ":IR",
+        ":LinalgDialect",
+        ":MemRefDialect",
+        ":NVGPUDialect",
+        ":NVGPUTransformOpsIncGen",
+        ":Support",
+        ":TransformDialect",
+        ":VectorDialect",
+        "//llvm:Support",
+    ],
+)
+
+# TableGen sources describing the NVGPU transform ops.
+td_library(
+    name = "NVGPUTransformOpsTdFiles",
+    srcs = glob(["include/mlir/Dialect/NVGPU/TransformOps/*.td"]),
+    includes = ["include"],
+    deps = [":TransformDialectTdFiles"],
+)
+
+# Generates the op declarations (.h.inc) and definitions (.cpp.inc) for the
+# NVGPU transform ops from NVGPUTransformOps.td.
+gentbl_cc_library(
+    name = "NVGPUTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (["-gen-op-decls"], "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc"),
+        (["-gen-op-defs"], "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td",
+    deps = [":NVGPUTransformOpsTdFiles"],
+)
+
cc_library(
name = "NVGPUUtils",
srcs = ["lib/Dialect/NVGPU/Utils/MMAUtils.cpp"],
@@ -7685,6 +7743,7 @@ cc_library(
":NVGPUPassIncGen",
":NVGPUToNVVM",
":NVGPUTransforms",
+ ":NVGPUTransformOps",
":NVVMDialect",
":OpenACCDialect",
":OpenMPDialect",
More information about the Mlir-commits
mailing list