[Mlir-commits] [mlir] 40deed4 - [mlir][Transform] Introduce nvgpu transform extensions
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jun 26 09:21:59 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-26T16:21:28Z
New Revision: 40deed40ae77ba22f7c72693903752ab6bfeb4e7
URL: https://github.com/llvm/llvm-project/commit/40deed40ae77ba22f7c72693903752ab6bfeb4e7
DIFF: https://github.com/llvm/llvm-project/commit/40deed40ae77ba22f7c72693903752ab6bfeb4e7.diff
LOG: [mlir][Transform] Introduce nvgpu transform extensions
Mapping to NVGPU operations such as mma.sync with mixed precision and ldmatrix with transposes and
various data types involves complex matchings from low-level IR.
This is akin to raising complex patterns after unnecessarily having lost structural information.
To avoid such unnecessary complexity, introduce a direct mapping step from a matmul on memrefs
to distributed NVGPU vector abstractions.
In this context, mapping to specific mma.sync operations is trivial and consists in simply
translating the documentation into indexing expressions.
Correctness is demonstrated with an end-to-end integration test.
Differential Revision: https://reviews.llvm.org/D153420
Added:
mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir
Modified:
mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/NVGPU/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
index d48de9df4b71b..49a07c997a30c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(IR)
+add_subdirectory(TransformOps)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name NVGPU)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..d75ae3dd5d017
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS NVGPUTransformOps.td)
+mlir_tablegen(NVGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(NVGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRNVGPUTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
new file mode 100644
index 0000000000000..0c7b9d865aa24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h
@@ -0,0 +1,43 @@
+//===- NVGPUTransformOps.h - NVGPU transform ops ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+class LinalgOp;
+} // namespace linalg
+
+namespace nvgpu {
+/// Register the NVGPU transform ops as an extension of the Transform dialect
+/// in `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace nvgpu
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// NVGPU Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_NVGPU_TRANSFORMOPS_NVGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
new file mode 100644
index 0000000000000..168a445b62ccf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -0,0 +1,51 @@
+//===- NVGPUTransformOps.td - NVGPU transform ops ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVGPU_TRANSFORM_OPS
+#define NVGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+def RewriteMatmulAsMmaSyncOp :
+ Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformEachOpTrait,
+ TransformOpInterface,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Rewrite a matmul operation on memref to an mma.sync operation on vectors.
+
+ Memory copies with the required access patterns are automatically inserted.
+ Operations that do not have a 1-1 mapping to mma.sync operations are left
+ unchanged.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp linalgOp,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+#endif // NVGPU_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index db15dff136cd1..106d5af5cfb7d 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
@@ -137,6 +138,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
transform::registerPDLExtension(registry);
diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
index 7117520599fa6..63b4d8b99f53f 100644
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(Utils)
+add_subdirectory(TransformOps)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..973e5268389fd
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_dialect_library(MLIRNVGPUTransformOps
+ NVGPUTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU/TransformOps
+
+ DEPENDS
+ MLIRNVGPUTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAffineDialect
+ MLIRArithDialect
+ MLIRIR
+ MLIRLinalgDialect
+ MLIRNVGPUDialect
+ MLIRParser
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRVectorTransforms
+ )
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
new file mode 100644
index 0000000000000..887b375aac788
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -0,0 +1,401 @@
+//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::nvgpu;
+using namespace mlir::transform;
+
+#define DEBUG_TYPE "nvgpu-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+//===----------------------------------------------------------------------===//
+// RewriteMatmulAsMmaSyncOp
+//===----------------------------------------------------------------------===//
+
+/// Pair of affine expressions used as the (row, column) indices of a single
+/// 2-D memory access.
+struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
+  RowColIndexing(AffineExpr row, AffineExpr col)
+      : std::pair<AffineExpr, AffineExpr>(row, col) {}
+
+  /// Expression computing the row index.
+  AffineExpr row() const { return this->first; }
+  /// Expression computing the column index.
+  AffineExpr col() const { return this->second; }
+
+  /// Print both expressions to `os`.
+  void print(llvm::raw_ostream &os) const {
+    os << "- indexing: " << row() << ", " << col();
+  }
+};
+
+/// Helper struct to provide a simple mapping from matmul operations to the
+/// corresponding mma.sync operation. This is constrained to the case where the
+/// matmul matches the mma.sync operation 1-1.
+struct MmaSyncBuilder {
+ /// `b` is stored by reference; `loc` and `laneId` are stored by value.
+ MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
+ : b(b), loc(loc), laneId(laneId) {}
+
+ /// Callback returning the list of `(row, col)` indexings of one operand,
+ /// with the affine expressions built in the given context.
+ using IndexCalculator =
+ std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
+
+ /// Create the mma.sync operation corresponding to `linalgOp` along with all
+ /// the supporting load/store and vector operations.
+ FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
+
+private:
+ /// Everything needed to emit one specific mma.sync variant.
+ struct MmaSyncInfo {
+ /// Index calculators for the lhs, rhs and result operands, in this order.
+ std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
+ /// Vector shapes of the lhs, rhs and result operands, in this order.
+ std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
+ vectorShapes;
+ /// Shape forwarded to the created nvgpu.mma.sync operation.
+ SmallVector<int64_t> mmaShape;
+ /// Flag forwarded to the created nvgpu.mma.sync operation.
+ bool tf32Enabled;
+ };
+
+ /// Return the index calculators, vector shapes and mma.sync attributes
+ /// matching the given `opShape` and `elementalTypes`, or failure if this
+ /// combination is not supported. This is the toplevel switch that should
+ /// just be Tablegen'd in the future.
+ FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
+ TypeRange elementalTypes);
+
+ //===--------------------------------------------------------------------===//
+ // Instruction-specific row, column indexing expression builders.
+ // These should all be declaratively specified via Tablegen in the future.
+ // The Tablegen specification should be as straightforward as possible to
+ // only model the existing size and type combinations.
+ //===--------------------------------------------------------------------===//
+ //
+ // TODO: Tablegen all this.
+ //===--------------------------------------------------------------------===//
+ // m16n8k4 tf32 case.
+ //===--------------------------------------------------------------------===//
+ /// From the NVIDIA doc:
+ ///   groupID         = %laneid >> 2
+ ///   threadIDInGroup = %laneid % 4
+ ///   row = groupID      for a0
+ ///         groupID + 8  for a1
+ ///   col = threadIDInGroup
+ static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{groupID, threadIDInGroup},
+ RowColIndexing{groupID + 8, threadIDInGroup}};
+ }
+
+ /// From the NVIDIA doc:
+ ///   groupID         = %laneid >> 2
+ ///   threadIDInGroup = %laneid % 4
+ ///   row = threadIDInGroup
+ ///   col = groupID
+ static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{threadIDInGroup, groupID}};
+ }
+ /// From the NVIDIA doc:
+ ///   groupID         = %laneid >> 2
+ ///   threadIDInGroup = %laneid % 4
+ ///   row = groupID      for c0 and c1
+ ///         groupID + 8  for c2 and c3
+ ///   col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
+ static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
+ auto dim = getAffineDimExpr(0, ctx);
+ AffineExpr groupID = dim.floorDiv(4);
+ AffineExpr threadIDInGroup = dim % 4;
+ return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
+ RowColIndexing{groupID, threadIDInGroup * 2 + 1},
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
+ RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
+ }
+
+ //===--------------------------------------------------------------------===//
+ /// Helper functions to create customizable load and store operations. The
+ /// specific shapes of each MMA instruction are passed via the
+ /// IndexCalculator callback.
+ //===--------------------------------------------------------------------===//
+ /// Build a list of memref.load operations indexed at `(row, col)` indices
+ /// that make sense for a particular MMA instruction and specified via the
+ /// IndexCalculator callback.
+ SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn);
+
+ /// Perform a distributed load of a vector operand of `vectorShape` for a
+ /// particular MMA instruction whose `(row, col)` indices are specified via
+ /// the IndexCalculator callback. Each `laneId` loads the subportion of the
+ /// data that makes sense for the particular MMA operation.
+ /// The `vectorShape` matches existing NVGPU dialect op specification but
+ /// could also be flattened in the future if needed for simplification.
+ Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn,
+ ArrayRef<int64_t> vectorShape);
+
+ /// Build a list of memref.store operations indexed at `(row, col)` indices
+ /// that make sense for a particular MMA instruction and specified via the
+ /// IndexCalculator callback.
+ SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+ ValueRange toStore,
+ OpFoldResult laneId, Value memref,
+ IndexCalculator indexFn);
+
+ /// Perform a distributed store of a vector operand of `vectorShape` for a
+ /// particular MMA instruction whose `(row, col)` indices are specified via
+ /// the IndexCalculator callback. Each `laneId` stores the subportion of the
+ /// data that makes sense for the particular MMA operation.
+ /// The `vectorShape` matches existing NVGPU dialect op specification but
+ /// could also be flattened in the future if needed for simplification.
+ SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+ OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+ Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
+
+ /// Builder that `buildMmaSync` passes to the helper functions.
+ OpBuilder &b;
+ /// Location that `buildMmaSync` passes to the helper functions.
+ Location loc;
+ /// Lane id that `buildMmaSync` passes to the helper functions.
+ OpFoldResult laneId;
+};
+
+//===--------------------------------------------------------------------===//
+/// Helper functions to create customizable load and stores operations. The
+/// specific shapes of each MMA instruction are passed via the
+/// IndexCalculator callback.
+//===--------------------------------------------------------------------===//
+
+/// Visit every element position of `vector` in increasing linearized order
+/// and call `reduceFn(applyFn(vector, linearIdx, indices), linearIdx, indices)`
+/// where `indices` is the delinearized form of `linearIdx`. The vector type
+/// must have at least one dimension since the leading dimension is read.
+template <typename ApplyFn, typename ReduceFn>
+static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
+                                           ReduceFn reduceFn) {
+  ArrayRef<int64_t> shape = vector.getType().cast<VectorType>().getShape();
+  auto strides = computeStrides(shape);
+  // Total number of elements: leading dimension times its stride.
+  const int64_t numElements = shape[0] * strides[0];
+  for (int64_t linearIdx = 0; linearIdx < numElements; ++linearIdx) {
+    auto indices = delinearize(linearIdx, strides);
+    reduceFn(applyFn(vector, linearIdx, indices), linearIdx, indices);
+  }
+}
+
+/// Emit one memref.load from `memref` per `(row, col)` indexing returned by
+/// `indexFn`; each index is obtained by applying the corresponding affine
+/// expression to `laneId`. The loaded values are returned in the order in
+/// which `indexFn` lists the indexings.
+SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+                                                    OpFoldResult laneId,
+                                                    Value memref,
+                                                    IndexCalculator indexFn) {
+  // Turn an affine expression of the lane id into an index Value.
+  auto materializeIndex = [&](AffineExpr expr) -> Value {
+    OpFoldResult folded =
+        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
+    return getValueOrCreateConstantIndexOp(b, loc, folded);
+  };
+
+  SmallVector<Value> loadedValues;
+  for (const RowColIndexing &rowCol : indexFn(b.getContext())) {
+    // Keep the creation order: row index, column index, then the load.
+    Value rowIdx = materializeIndex(rowCol.row());
+    Value colIdx = materializeIndex(rowCol.col());
+    Value loaded =
+        b.create<memref::LoadOp>(loc, memref, ValueRange{rowIdx, colIdx});
+    loadedValues.push_back(loaded);
+  }
+  return loadedValues;
+}
+
+/// Load the scalars selected by `indexFn` from `memref` and assemble them
+/// into a vector of shape `vectorShape` whose element type is the one of
+/// `memref`. The i-th loaded scalar lands at the i-th linearized position.
+Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
+    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  SmallVector<Value> scalars =
+      buildMemrefLoads(b, loc, laneId, memref, indexFn);
+
+  // Seed the vector with a splat of the first scalar, then overwrite every
+  // position with its own scalar.
+  VectorType vectorType =
+      VectorType::get(vectorShape, getElementTypeOrSelf(memref.getType()));
+  Value assembled = b.create<vector::SplatOp>(loc, vectorType, scalars[0]);
+  foreachIndividualVectorElement(
+      assembled,
+      /*applyFn=*/
+      [&](Value, int64_t linearIdx, ArrayRef<int64_t>) {
+        return scalars[linearIdx];
+      },
+      /*reduceFn=*/
+      [&](Value scalar, int64_t, ArrayRef<int64_t> indices) {
+        assembled = b.create<vector::InsertOp>(loc, scalar, assembled, indices);
+      });
+
+  return assembled;
+}
+
+/// Emit one memref.store into `memref` per value of `toStore`, pairing the
+/// i-th value with the i-th `(row, col)` indexing returned by `indexFn`
+/// (both lists must have the same length). Each index is obtained by applying
+/// the corresponding affine expression to `laneId`. Returns the created store
+/// operations in creation order.
+SmallVector<Operation *>
+MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+                                  ValueRange toStore, OpFoldResult laneId,
+                                  Value memref, IndexCalculator indexFn) {
+  // Turn an affine expression of the lane id into an index Value.
+  auto materializeIndex = [&](AffineExpr expr) -> Value {
+    OpFoldResult folded =
+        affine::makeComposedFoldedAffineApply(b, loc, expr, laneId);
+    return getValueOrCreateConstantIndexOp(b, loc, folded);
+  };
+
+  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
+  SmallVector<Operation *> storeOps;
+  for (auto [rowCol, scalar] : llvm::zip_equal(indexings, toStore)) {
+    // Keep the creation order: row index, column index, then the store.
+    Value rowIdx = materializeIndex(rowCol.row());
+    Value colIdx = materializeIndex(rowCol.col());
+    Operation *storeOp = b.create<memref::StoreOp>(
+        loc, scalar, memref, ValueRange{rowIdx, colIdx});
+    storeOps.push_back(storeOp);
+  }
+  return storeOps;
+}
+
+/// Extract every element of `vectorToStore` in linearized order and store
+/// the resulting scalars into `memref` at the `(row, col)` positions given by
+/// `indexFn`. Returns the created store operations.
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
+    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
+  SmallVector<Value> scalars;
+  scalars.reserve(32);
+  foreachIndividualVectorElement(
+      vectorToStore,
+      /*applyFn=*/
+      [&](Value source, int64_t, ArrayRef<int64_t> indices) {
+        // `source` is `vectorToStore`, handed back by the traversal helper.
+        return b.create<vector::ExtractOp>(loc, source, indices);
+      },
+      /*reduceFn=*/
+      [&](Value extracted, int64_t, ArrayRef<int64_t>) {
+        scalars.push_back(extracted);
+      });
+  return buildMemrefStores(b, loc, scalars, laneId, memref, indexFn);
+}
+
+/// Copy the three given shapes into owning vectors and bundle them as a
+/// `(lhs, rhs, res)` tuple.
+static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+                  SmallVector<int64_t>>
+makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+                 ArrayRef<int64_t> res) {
+  return std::make_tuple(SmallVector<int64_t>(lhs.begin(), lhs.end()),
+                         SmallVector<int64_t>(rhs.begin(), rhs.end()),
+                         SmallVector<int64_t>(res.begin(), res.end()));
+}
+
+/// Select the mma.sync description matching `opShape` (m, n, k) and the
+/// (lhs, rhs, res) `elementalTypes`. Only the 16x8x4 shape with three f32
+/// element types is handled; anything else yields failure.
+FailureOr<MmaSyncBuilder::MmaSyncInfo>
+MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
+                                    TypeRange elementalTypes) {
+  // TODO: Tablegen all this.
+  Type f32 = b.getF32Type();
+  const bool isM16N8K4 = opShape == ArrayRef<int64_t>{16, 8, 4};
+  const bool isAllF32 = elementalTypes == TypeRange{f32, f32, f32};
+  if (!isM16N8K4 || !isAllF32)
+    return failure();
+
+  // m16n8k4 tf32 case.
+  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
+                                     &MmaSyncBuilder::m16n8k4tf32Rhs,
+                                     &MmaSyncBuilder::m16n8k4tf32Res),
+                     makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
+                     SmallVector<int64_t>{opShape.begin(), opShape.end()},
+                     /*tf32Enabled=*/true};
+}
+
+/// Build the nvgpu.mma.sync equivalent of `linalgOp`: load the lhs, rhs and
+/// result operands from their memrefs into vectors, create the mma.sync
+/// operation and store its result back into the result memref.
+///
+/// Returns the created mma.sync operation, or failure (without creating any
+/// operation) when the operands are not 2-D memrefs or when the shape /
+/// element type combination is not supported.
+FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
+  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
+
+  // Only 2-D memref operands are supported. Report failure rather than
+  // asserting so that e.g. a matmul on tensors is rejected gracefully, also
+  // in builds without assertions.
+  auto lhsMemrefType = dyn_cast<MemRefType>(lhsMemref.getType());
+  auto rhsMemrefType = dyn_cast<MemRefType>(rhsMemref.getType());
+  auto resMemrefType = dyn_cast<MemRefType>(resMemref.getType());
+  if (!lhsMemrefType || !rhsMemrefType || !resMemrefType)
+    return failure();
+  if (lhsMemrefType.getRank() != 2 || rhsMemrefType.getRank() != 2 ||
+      resMemrefType.getRank() != 2)
+    return failure();
+
+  // Problem sizes: lhs is m x k and rhs is k x n.
+  int64_t m = lhsMemrefType.getShape()[0];
+  int64_t n = rhsMemrefType.getShape()[1];
+  int64_t k = lhsMemrefType.getShape()[1];
+  Type lhsType = lhsMemrefType.getElementType();
+  Type rhsType = rhsMemrefType.getElementType();
+  Type resType = resMemrefType.getElementType();
+
+  FailureOr<MmaSyncInfo> maybeInfo =
+      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
+  if (failed(maybeInfo))
+    return failure();
+
+  MmaSyncInfo info = *maybeInfo;
+  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
+  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
+  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+                                            lhsIndexFn, lhsShape);
+  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+                                            rhsIndexFn, rhsShape);
+  // The current content of the result memref is the accumulator operand.
+  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+                                            resIndexFn, resShape);
+  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
+                                   info.tf32Enabled);
+  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+                                 resShape);
+  return res.getDefiningOp();
+}
+
+/// Rewrite the payload operation `linalgOp` into an nvgpu.mma.sync operation
+/// with its supporting loads and stores, then erase `linalgOp`.
+///
+/// Emits a silenceable error, with a note pointing at the payload operation,
+/// when `linalgOp` is not a linalg.matmul or when the matmul cannot be mapped
+/// to a supported mma.sync operation.
+DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
+    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  bool rewritten = false;
+  // TODO: more robust detection of matmulOp, with transposes etc.
+  if (isa<linalg::MatmulOp>(linalgOp.getOperation())) {
+    Location loc = linalgOp.getLoc();
+    // TODO: more robust computation of laneId, for now assume a single warp.
+    Value laneId = rewriter.create<gpu::ThreadIdOp>(
+        loc, rewriter.getIndexType(), gpu::Dimension::x);
+    rewritten = succeeded(
+        MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp));
+    // Do not leave a dangling gpu.thread_id in the payload when nothing was
+    // rewritten.
+    if (!rewritten && laneId.use_empty())
+      rewriter.eraseOp(laneId.getDefiningOp());
+  }
+
+  if (!rewritten) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "unsupported target op: " << linalgOp;
+    diag.attachNote(linalgOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  rewriter.eraseOp(linalgOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension registering the NVGPU transform operations.
+/// Every dialect whose operations the transforms in this file create is
+/// declared as generated so that it is loaded before the transforms run.
+class NVGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          NVGPUTransformDialectExtension> {
+public:
+  NVGPUTransformDialectExtension() {
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<affine::AffineDialect>();
+    // gpu.thread_id is created to compute the lane id.
+    declareGeneratedDialect<gpu::GPUDialect>();
+    // memref.load / memref.store are created to move operands.
+    declareGeneratedDialect<memref::MemRefDialect>();
+    declareGeneratedDialect<nvgpu::NVGPUDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
+
+/// Add the NVGPU transform dialect extension to `registry`.
+void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<NVGPUTransformDialectExtension>();
+}
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
new file mode 100644
index 0000000000000..55ff52bb5190a
--- /dev/null
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK: #[[$div4:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #[[$mod4:.*]] = affine_map<()[s0] -> (s0 mod 4)>
+// CHECK: #[[$div4p8:.*]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
+// CHECK: #[[$map3:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
+// CHECK: #[[$map4:.*]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 1)>
+
+// CHECK-LABEL: func.func @matmul_16x8x4xf32_global
+func.func @matmul_16x8x4xf32_global(
+ %A: memref<16x4xf32>, %B: memref<4x8xf32>, %C: memref<16x8xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<16x4xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<4x8xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<16x8xf32>) {
+
+// CHECK: %[[TIDX:.*]] = gpu.thread_id x
+// CHECK: %[[VAL_4:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_5:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]][%[[VAL_4]], %[[VAL_5]]] : memref<16x4xf32>
+// CHECK: %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
+// CHECK: %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
+// CHECK: %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
+// CHECK: %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
+//
+// CHECK: %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
+// CHECK: %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
+// CHECK: %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
+//
+// CHECK: %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_19:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_2]][%[[VAL_18]], %[[VAL_19]]] : memref<16x8xf32>
+// CHECK: %[[VAL_21:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_22:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_2]][%[[VAL_21]], %[[VAL_22]]] : memref<16x8xf32>
+// CHECK: %[[VAL_24:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_25:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_2]][%[[VAL_24]], %[[VAL_25]]] : memref<16x8xf32>
+// CHECK: %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
+// CHECK: %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
+// CHECK: %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[VAL_29]], %[[VAL_33]] [1, 1] : f32 into vector<2x2xf32>
+//
+// CHECK: %[[VAL_35:.*]] = nvgpu.mma.sync(%[[LHS]], %[[RHS]], %[[RES]]) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+//
+// CHECK: %[[VAL_36:.*]] = vector.extract %[[VAL_35]][0, 0] : vector<2x2xf32>
+// CHECK: %[[VAL_37:.*]] = vector.extract %[[VAL_35]][0, 1] : vector<2x2xf32>
+// CHECK: %[[VAL_38:.*]] = vector.extract %[[VAL_35]][1, 0] : vector<2x2xf32>
+// CHECK: %[[VAL_39:.*]] = vector.extract %[[VAL_35]][1, 1] : vector<2x2xf32>
+// CHECK: %[[VAL_40:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_41:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_36]], %[[VAL_2]][%[[VAL_40]], %[[VAL_41]]] : memref<16x8xf32>
+// CHECK: %[[VAL_42:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
+// CHECK: %[[VAL_43:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_37]], %[[VAL_2]][%[[VAL_42]], %[[VAL_43]]] : memref<16x8xf32>
+// CHECK: %[[VAL_44:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_45:.*]] = affine.apply #[[$map3]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_38]], %[[VAL_2]][%[[VAL_44]], %[[VAL_45]]] : memref<16x8xf32>
+// CHECK: %[[VAL_46:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
+// CHECK: %[[VAL_47:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
+// CHECK: memref.store %[[VAL_39]], %[[VAL_2]][%[[VAL_46]], %[[VAL_47]]] : memref<16x8xf32>
+// CHECK: return
+// CHECK: }
+ linalg.matmul ins(%A, %B: memref<16x4xf32>, memref<4x8xf32>)
+ outs(%C: memref<16x8xf32>)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+ : (!transform.any_op) -> ()
+}
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir
new file mode 100644
index 0000000000000..f35087f4b9c1e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/transform-mma-sync-matmul-f32.mlir
@@ -0,0 +1,178 @@
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: | FileCheck %s --check-prefix=CHECK-MMA-SYNC
+
+// CHECK-MMA-SYNC-LABEL: func @main() {
+// CHECK-MMA-SYNC: nvgpu.mma.sync(%{{.*}}) {mmaShape = [16, 8, 4], tf32Enabled}
+// CHECK-MMA-SYNC-SAME: : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+
+// Tested to run locally in 1.7s.
+
+// RUN: mlir-opt %s \
+// RUN: -test-transform-dialect-interpreter \
+// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -gpu-kernel-outlining \
+// RUN: -convert-scf-to-cf \
+// RUN: -convert-vector-to-llvm \
+// RUN: -convert-math-to-llvm \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -convert-arith-to-llvm \
+// RUN: -finalize-memref-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -canonicalize \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80}))' \
+// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
+// RUN: -gpu-to-llvm \
+// RUN: -convert-func-to-llvm \
+// RUN: -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+!lhs_memref_type = memref<16x4xf32>
+!rhs_memref_type = memref<4x8xf32>
+!res_memref_type = memref<16x8xf32>
+
+func.func @compute_linspace_val(%ridx: index, %cidx: index, %strideCidx: index) -> f32 {
+ %r = arith.index_cast %ridx : index to i32
+ %c = arith.index_cast %cidx : index to i32
+ %strideC = arith.index_cast %strideCidx : index to i32
+ %2 = arith.muli %r, %strideC : i32
+ %3 = arith.addi %c, %2 : i32
+ %4 = arith.sitofp %3 : i32 to f32
+ return %4: f32
+}
+
+func.func @main() {
+ %lhs = memref.alloc() : !lhs_memref_type
+ %rhs = memref.alloc() : !rhs_memref_type
+ %res = memref.alloc() : !res_memref_type
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %M = memref.dim %res, %c0 : !res_memref_type
+ %N = memref.dim %res, %c1 : !res_memref_type
+ %K = memref.dim %lhs, %c1 : !lhs_memref_type
+
+ %f1 = arith.constant 1.0e+00 : f32
+ %f0 = arith.constant 0.0e+00 : f32
+ %c32 = arith.constant 32 : index
+
+ // Initialize the lhs matrix with a linspace function.
+ scf.for %r = %c0 to %M step %c1 {
+ scf.for %c = %c0 to %K step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %K) : (index, index, index) -> f32
+ memref.store %idx, %lhs[%r, %c] : !lhs_memref_type
+ }
+ }
+ // Initialize the rhs matrix with a linspace function.
+ scf.for %r = %c0 to %K step %c1 {
+ scf.for %c = %c0 to %N step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+ memref.store %idx, %rhs[%r, %c] : !rhs_memref_type
+ }
+ }
+ // Initialize the res matrix with a linspace function.
+ scf.for %r = %c0 to %M step %c1 {
+ scf.for %c = %c0 to %N step %c1 {
+ %idx = func.call @compute_linspace_val(%r, %c, %N) : (index, index, index) -> f32
+ memref.store %idx, %res[%r, %c] : !res_memref_type
+ }
+ }
+
+ %ulhs = memref.cast %lhs : !lhs_memref_type to memref<*xf32>
+ %urhs = memref.cast %rhs : !rhs_memref_type to memref<*xf32>
+ %ures = memref.cast %res : !res_memref_type to memref<*xf32>
+ gpu.host_register %ulhs : memref<*xf32>
+ gpu.host_register %urhs : memref<*xf32>
+ gpu.host_register %ures : memref<*xf32>
+
+ // Print the memrefs before computation.
+ call @printMemrefF32(%ulhs) : (memref<*xf32>) -> ()
+ // CHECK: [0, 1, 2, 3],
+ // CHECK: [4, 5, 6, 7],
+ // CHECK: [8, 9, 10, 11],
+ // CHECK: [12, 13, 14, 15],
+ // CHECK: [16, 17, 18, 19],
+ // CHECK: [20, 21, 22, 23],
+ // CHECK: [24, 25, 26, 27],
+ // CHECK: [28, 29, 30, 31],
+ // CHECK: [32, 33, 34, 35],
+ // CHECK: [36, 37, 38, 39],
+ // CHECK: [40, 41, 42, 43],
+ // CHECK: [44, 45, 46, 47],
+ // CHECK: [48, 49, 50, 51],
+ // CHECK: [52, 53, 54, 55],
+ // CHECK: [56, 57, 58, 59],
+ // CHECK: [60, 61, 62, 63]
+
+ call @printMemrefF32(%urhs) : (memref<*xf32>) -> ()
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 7],
+ // CHECK: [8, 9, 10, 11, 12, 13, 14, 15],
+ // CHECK: [16, 17, 18, 19, 20, 21, 22, 23],
+ // CHECK: [24, 25, 26, 27, 28, 29, 30, 31]
+
+ call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+ // CHECK: [0, 1, 2, 3, 4, 5, 6, 7],
+ // CHECK: [8, 9, 10, 11, 12, 13, 14, 15],
+ // CHECK: [16, 17, 18, 19, 20, 21, 22, 23],
+ // CHECK: [24, 25, 26, 27, 28, 29, 30, 31],
+ // CHECK: [32, 33, 34, 35, 36, 37, 38, 39],
+ // CHECK: [40, 41, 42, 43, 44, 45, 46, 47],
+ // CHECK: [48, 49, 50, 51, 52, 53, 54, 55],
+ // CHECK: [56, 57, 58, 59, 60, 61, 62, 63],
+ // CHECK: [64, 65, 66, 67, 68, 69, 70, 71],
+ // CHECK: [72, 73, 74, 75, 76, 77, 78, 79],
+ // CHECK: [80, 81, 82, 83, 84, 85, 86, 87],
+ // CHECK: [88, 89, 90, 91, 92, 93, 94, 95],
+ // CHECK: [96, 97, 98, 99, 100, 101, 102, 103],
+ // CHECK: [104, 105, 106, 107, 108, 109, 110, 111],
+ // CHECK: [112, 113, 114, 115, 116, 117, 118, 119],
+ // CHECK: [120, 121, 122, 123, 124, 125, 126, 127]
+
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+
+ linalg.matmul ins(%lhs, %rhs: !lhs_memref_type, !rhs_memref_type)
+ outs(%res: !res_memref_type)
+
+ gpu.terminator
+ }
+
+
+ // Print the result memref after computation.
+ call @printMemrefF32(%ures) : (memref<*xf32>) -> ()
+
+ // CHECK: [112, 119, 126, 133, 140, 147, 154, 161],
+ // CHECK: [312, 335, 358, 381, 404, 427, 450, 473],
+ // CHECK: [512, 551, 590, 629, 668, 707, 746, 785],
+ // CHECK: [712, 767, 822, 877, 932, 987, 1042, 1097],
+ // CHECK: [912, 983, 1054, 1125, 1196, 1267, 1338, 1409],
+ // CHECK: [1112, 1199, 1286, 1373, 1460, 1547, 1634, 1721],
+ // CHECK: [1312, 1415, 1518, 1621, 1724, 1827, 1930, 2033],
+ // CHECK: [1512, 1631, 1750, 1869, 1988, 2107, 2226, 2345],
+ // CHECK: [1712, 1847, 1982, 2117, 2252, 2387, 2522, 2657],
+ // CHECK: [1912, 2063, 2214, 2365, 2516, 2667, 2818, 2969],
+ // CHECK: [2112, 2279, 2446, 2613, 2780, 2947, 3114, 3281],
+ // CHECK: [2312, 2495, 2678, 2861, 3044, 3227, 3410, 3593],
+ // CHECK: [2512, 2711, 2910, 3109, 3308, 3507, 3706, 3905],
+ // CHECK: [2712, 2927, 3142, 3357, 3572, 3787, 4002, 4217],
+ // CHECK: [2912, 3143, 3374, 3605, 3836, 4067, 4298, 4529],
+ // CHECK: [3112, 3359, 3606, 3853, 4100, 4347, 4594, 4841]
+
+ return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
+ : (!transform.any_op) -> ()
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 479e04e7df686..72c565b729d55 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2758,6 +2758,64 @@ cc_library(
],
)
+cc_library(
+ name = "NVGPUTransformOps",
+ srcs = glob([
+ "lib/Dialect/NVGPU/TransformOps/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/mlir/Dialect/NVGPU/TransformOps/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":ArithDialect",
+ ":ArithUtils",
+ ":AffineDialect",
+ ":DialectUtils",
+ ":GPUDialect",
+ ":IR",
+ ":LinalgDialect",
+ ":MemRefDialect",
+ ":NVGPUDialect",
+ ":NVGPUTransformOpsIncGen",
+ ":Support",
+ ":TransformDialect",
+ ":VectorDialect",
+ "//llvm:Support",
+ ],
+)
+
+td_library(
+ name = "NVGPUTransformOpsTdFiles",
+ srcs = glob([
+ "include/mlir/Dialect/NVGPU/TransformOps/*.td",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":TransformDialectTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "NVGPUTransformOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td",
+ deps = [
+ ":NVGPUTransformOpsTdFiles",
+ ],
+)
+
cc_library(
name = "NVGPUUtils",
srcs = ["lib/Dialect/NVGPU/Utils/MMAUtils.cpp"],
@@ -7685,6 +7743,7 @@ cc_library(
":NVGPUPassIncGen",
":NVGPUToNVVM",
":NVGPUTransforms",
+ ":NVGPUTransformOps",
":NVVMDialect",
":OpenACCDialect",
":OpenMPDialect",
More information about the Mlir-commits
mailing list