[Mlir-commits] [mlir] ea2ed80 - [mlir][nvgpu] NFC - move NVGPU conversion helpers to NvGpu utils library

Christopher Bate llvmlistbot at llvm.org
Wed Oct 5 19:21:37 PDT 2022


Author: Christopher Bate
Date: 2022-10-05T20:21:27-06:00
New Revision: ea2ed80e6dac21bf19d87de8bef1ec01b00dbb8d

URL: https://github.com/llvm/llvm-project/commit/ea2ed80e6dac21bf19d87de8bef1ec01b00dbb8d
DIFF: https://github.com/llvm/llvm-project/commit/ea2ed80e6dac21bf19d87de8bef1ec01b00dbb8d.diff

LOG: [mlir][nvgpu] NFC - move NVGPU conversion helpers to NvGpu utils library

The ConvertVectorToGPU pass implementation contained a small private
support library for performing various calculations during the conversion
between `vector` operations and the `nvgpu.mma.sync` and `nvgpu.ldmatrix`
operations. This change moves the support library under
`Dialect/NVGPU/Utils` because the functions have wider utility. Some
documentation comments are added or improved.
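
As a quick illustration (not part of this commit), conversion code can now
reach these helpers through the new public header. A minimal sketch, assuming
a hypothetical wrapper `inspectWarpOperand` around the relocated functions:

  #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

  using namespace mlir;

  // Query warp-level fragment info for a vector.transfer_read/write or
  // vector.contraction op, then derive the per-lane coordinate map.
  static LogicalResult inspectWarpOperand(Operation *op, OpBuilder &builder) {
    FailureOr<nvgpu::WarpMatrixInfo> info = nvgpu::getWarpMatrixInfo(op);
    if (failed(info))
      return failure();
    // Maps (laneId, logicalValueId) to the coordinates of the element each
    // thread owns within the warp-level matrix operand.
    FailureOr<AffineMap> map = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op->getLoc(), builder, *info);
    return success(succeeded(map));
  }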

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135303

Added: 
    mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
    mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt
    mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp

Modified: 
    mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/NVGPU/CMakeLists.txt

Removed: 
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
    mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h


################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
similarity index 69%
rename from mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
rename to mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
index aee429edbad7c..699e9fdb25a0b 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -1,4 +1,4 @@
-//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===//
+//===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,24 +6,30 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides utilities to assist in the lowering of Vector operations
-// to GPU dialect MMA operations.
+// This file provides utilities to assist in the lowering of other dialects
+// (e.g. Vector) to `nvgpu.mma.*` dialect operations.
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
-#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
+#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
+
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
+namespace vector {
+enum class IteratorType : uint32_t;
+class ContractionOp;
+} // namespace vector
+
+namespace NVVM {
+enum class MMALayout : uint32_t;
+} // namespace NVVM
+
 namespace nvgpu {
 
+/// Represents the role of an operand in an MMA instruction:
+/// `result := matmul(A, B) + C`
 enum class MatMulOperandRole : int32_t { A = 0, B, C };
 
 /// Collects information about a warp-level matrix operand represented by a
@@ -33,8 +39,10 @@ struct WarpMatrixInfo {
   MatMulOperandRole operandRole;
 };
 
-/// Given an op that operates on a VectorType representing a warp-level matrix
-/// operand, the function returns a struct containing relevant type information.
+/// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the
+/// vector operand. If op is a `vector.transfer_read`, `vector.contraction`, or
+/// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result.
+/// Otherwise, return failure.
 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
 
 /// Returns the number of bits in a single tile row. It is either 128, 256, or
@@ -67,6 +75,8 @@ FailureOr<AffineMap>
 getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                   const WarpMatrixInfo &fragmentType);
 
+/// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
+/// `nvvm.ldmatrix`.
 struct LdMatrixParams {
   VectorType fragmentType;
   bool isAccum;
@@ -75,6 +85,8 @@ struct LdMatrixParams {
   NVVM::MMALayout targetLayout;
 };
 
+/// Given `type` that contains info for a warp-matrix operand and whether or not
+/// the load is a transposed load, return the LdMatrixParams.
 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                             bool transpose);
 /// Returns an AffineMap which maps a single dimension representing the laneId
@@ -84,8 +96,10 @@ FailureOr<AffineMap>
 getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                                const LdMatrixParams &params);
 
-// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted
-// to MMA matmul.
+/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
+/// converted to `nvgpu.mma.sync`. This specific form is meant to indicate that
+/// the vector operands are organized such that the reduction dimension is
+/// contiguous.
 struct PrepareContractToGPUMMASync
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
@@ -97,4 +111,4 @@ struct PrepareContractToGPUMMASync
 } // namespace nvgpu
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H

diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
index b29551a9348fd..d9dbe349a0d63 100644
--- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_conversion_library(MLIRVectorToGPU
   VectorToGPU.cpp
-  NvGpuSupport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
@@ -14,6 +13,7 @@ add_mlir_conversion_library(MLIRVectorToGPU
   MLIRLLVMDialect
   MLIRMemRefDialect
   MLIRNVGPUDialect
+  MLIRNVGPUUtils
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorUtils

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 9574035f70aaa..f4528b178e656 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -10,16 +10,17 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <type_traits>
-
-#include "NvGpuSupport.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
+#include <type_traits>
+
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"

diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
index 9f57627c321fb..7117520599fa6 100644
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(Utils)
 add_subdirectory(Transforms)

diff --git a/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt
new file mode 100644
index 0000000000000..181784379b134
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRNVGPUUtils
+  MMAUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRLLVMDialect
+  MLIRNVGPUDialect
+  MLIRNVVMDialect
+  MLIRVectorDialect
+  MLIRIR
+  )

diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
similarity index 88%
rename from mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
rename to mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 5bd1757a87d1e..18fc4e600db9e 100644
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -1,37 +1,32 @@
-//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file provides utilities to assist in the lowering of Vector operations
-// to NvGPU dialect MMA operations.
-//
-//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
 
-#include "NvGpuSupport.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
-namespace mlir {
-namespace nvgpu {
-namespace {
+using namespace mlir;
+using namespace mlir::nvgpu;
 
 /// There are always 4 threads per [128|256|512] bit row.
-constexpr int64_t kThreadsPerRow = 4;
+static constexpr int64_t kThreadsPerRow = 4;
+static constexpr int64_t kNumRowsPerTile = 8;
 
-constexpr int64_t kNumRowsPerTile = 8;
-
-bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
   return operandType == MatMulOperandRole::C;
 }
 
 /// Returns the number of registers which compose a matrix fragment held by a
 /// single thread.
-int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
   int64_t lineSize = inferTileWidthInBits(type);
   auto shape = type.vectorType.getShape();
   return (shape[0] / kNumRowsPerTile) *
@@ -41,17 +36,16 @@ int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
 
 /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
 /// operand shape.
-std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
-                                    Type elementType, int64_t lineSizeBits) {
+static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
+                                           Type elementType,
+                                           int64_t lineSizeBits) {
   // For each 8x128bit square, a thread is responsible for one 32bit register.
   return {operandShape[0] / kNumRowsPerTile,
           (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
               lineSizeBits};
 }
 
-} // namespace
-
-FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
+FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
   WarpMatrixInfo info;
 
   // Determine the vector type.
@@ -84,7 +78,7 @@ FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
   return info;
 }
 
-int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
+int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
   bool isAcc = isAccumulatorOrResult(type.operandRole);
   Type elType = type.vectorType.getElementType();
   if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
@@ -97,7 +91,7 @@ int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
 }
 
 FailureOr<FragmentElementInfo>
-getMmaSyncRegisterType(const WarpMatrixInfo &type) {
+nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
   MLIRContext *ctx = type.vectorType.getContext();
   const bool isAccum = isAccumulatorOrResult(type.operandRole);
 
@@ -170,8 +164,8 @@ static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
 }
 
 FailureOr<AffineMap>
-getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
-                                  const WarpMatrixInfo &fragmentType) {
+nvgpu::getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                         const WarpMatrixInfo &fragmentType) {
   Type elementType = fragmentType.vectorType.getElementType();
   ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
@@ -205,8 +199,8 @@ getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                       (logicalValueIdDim % elementsPerRegister)});
 }
 
-FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
-                                                   bool transpose) {
+FailureOr<nvgpu::LdMatrixParams>
+nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
   LdMatrixParams params;
   Type elType = type.vectorType.getElementType();
   params.fragmentType = type.vectorType;
@@ -235,8 +229,8 @@ FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
 }
 
 FailureOr<AffineMap>
-getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
-                               const LdMatrixParams &params) {
+nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                                      const LdMatrixParams &params) {
   // One thread per 128b row.
   const int64_t kNumThreadsPerTile = kNumRowsPerTile;
   const int bitsPerElement = static_cast<int>(
@@ -273,9 +267,8 @@ getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
   return failure();
 }
 
-LogicalResult
-PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
-                                             PatternRewriter &rewriter) const {
+LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value lhs = op.getLhs();
   Value rhs = op.getRhs();
@@ -330,6 +323,3 @@ PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
       op.getIteratorTypes());
   return success();
 }
-
-} // namespace nvgpu
-} // namespace mlir

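Since `PrepareContractToGPUMMASync` is now declared in the public MMAUtils.h
header, a downstream pass can register the pattern directly. A minimal sketch
(the populate function name is hypothetical), assuming the consuming library
links MLIRNVGPUUtils as the VectorToGPU CMakeLists change above does:

  #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

  using namespace mlir;

  // Hypothetical helper registering the relocated rewrite pattern.
  static void populatePrepareContractPatterns(RewritePatternSet &patterns) {
    patterns.add<nvgpu::PrepareContractToGPUMMASync>(patterns.getContext());
  }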
