[Mlir-commits] [mlir] c0321ed - [mlir][gpu] Adding support for transposed mma_load_matrix

Thomas Raoux llvmlistbot at llvm.org
Mon Nov 28 19:36:10 PST 2022


Author: Quinn Dawkins
Date: 2022-11-29T03:35:49Z
New Revision: c0321edc26a7c02b46ed18c8f8bdcaaa9d9ce8a2

URL: https://github.com/llvm/llvm-project/commit/c0321edc26a7c02b46ed18c8f8bdcaaa9d9ce8a2
DIFF: https://github.com/llvm/llvm-project/commit/c0321edc26a7c02b46ed18c8f8bdcaaa9d9ce8a2.diff

LOG: [mlir][gpu] Adding support for transposed mma_load_matrix

Enables transposed gpu.subgroup_mma_load_matrix ops and updates the lowerings in Vector-to-GPU and GPU-to-SPIRV. This is needed so that matmuls with a transposed B operand can lower to wmma ops.

Taken over from author: stanley-nod <stanley at nod-labs.com>

Reviewed By: ThomasRaoux, antiagainst

Differential Revision: https://reviews.llvm.org/D138770
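
For illustration, the new optional `transpose` unit attribute appears on the op as in the updated VectorToGPU test below (the SSA names, indices, and 16x16xf16 shapes here are illustrative only). With this patch, GPU-to-SPIRV lowers the attribute to a column-major cooperative matrix load, while GPU-to-NVVM still bails out on it for now:

    // Plain (row-major) load of an A operand.
    %a = gpu.subgroup_mma_load_matrix %src[%c0, %c0] {leadDimension = 16 : index}
           : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
    // Transposed load of a B operand, using the new unit attribute.
    %b = gpu.subgroup_mma_load_matrix %src[%c0, %c0] {leadDimension = 16 : index, transpose}
           : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">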

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index cadc6858ba020..44684da5baccc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1102,7 +1102,7 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     determined using `indices`. The matrix being loaded into is the result.  The
     `leadDimension` attribute specifies the leading dimension size of the source
     matrix which eventually allows the lowering to determine the size of each
-    row.
+    row. If the `transpose` attribute is present, the op performs a transposed load.
 
     This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_compute`.
@@ -1117,7 +1117,8 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
 
   let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref,
                   Variadic<Index>:$indices,
-                  IndexAttr:$leadDimension);
+                  IndexAttr:$leadDimension,
+                  OptionalAttr<UnitAttr>:$transpose);
 
   let results = (outs GPU_MMAMatrix:$res);
 

diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 963a718dc5f36..47687d9df4536 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,6 +77,10 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
+    // TODO: Support transposed mma loads.
+    if (subgroupMmaLoadMatrixOp.getTranspose())
+      return failure();
+
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =

diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index ee4a7407feba7..5bc7301ff650c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -87,10 +87,12 @@ struct WmmaLoadOpToSPIRVLowering
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
         loc, i32Type, IntegerAttr::get(i32Type, stride));
-    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    bool useColMajor =
+        static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
     rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
-        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, coloumnMajor,
+        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
         spirv::MemoryAccessAttr());
     return success();
   }

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 771548b319b67..2734b5f1660be 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -92,6 +92,19 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
   return true;
 }
 
+// Return true if the given map represents a transposed matrix load,
+// i.e. (d0, d1, ...) -> (dn-1, dn-2).
+static bool isTransposeMatrixLoadMap(OpBuilder &b, AffineMap permutationMap) {
+  auto nDim = permutationMap.getNumDims();
+  if (nDim < 2)
+    return false;
+
+  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
+  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
+  return permutationMap ==
+         AffineMap::get(nDim, 0, {innerDim, outerDim}, b.getContext());
+}
+
 // Return the stride for the dimension 0 of |type| if it is a memref and has a
 // constant stride.
 static std::optional<int64_t>
@@ -129,9 +142,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                           readOp.getContext());
 
   if (!useNvGpu) {
-    // TODO: Support transpose once it is added to GPU dialect ops.
-    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
-    return map.isMinorIdentity() || map == broadcastInnerDim;
+    bool result = map.isMinorIdentity() || map == broadcastInnerDim ||
+                  isTransposeMatrixLoadMap(b, map);
+    return result;
   }
 
   return true;
@@ -445,9 +458,10 @@ static void convertTransferReadOp(vector::TransferReadOp op,
       gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                               op.getVectorType().getElementType(), fragType);
   OpBuilder b(op);
+  bool isTranspose = isTransposeMatrixLoadMap(b, map);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
-      b.getIndexAttr(*stride));
+      b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
   valueMapping[op.getResult()] = load;
 }
 

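For reference, the isTransposeMatrixLoadMap helper added above only matches permutation maps whose results are the two innermost dimensions swapped; the map names below are made up for illustration:

    // Matches (nDim = 2): the results are (d1, d0).
    #transpose_2d = affine_map<(d0, d1) -> (d1, d0)>
    // Matches with extra leading dimensions (nDim = 3): the results are (d2, d1).
    #transpose_3d = affine_map<(d0, d1, d2) -> (d2, d1)>
    // Does not match: the minor identity is handled by the pre-existing path.
    #identity_2d = affine_map<(d0, d1) -> (d0, d1)>
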
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index afe3d5d229b9c..fa2a40f7334d1 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -5,6 +5,7 @@
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #map4 = affine_map<(d0) -> (d0, 0)>
+#map5 = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @matmul
 //   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
@@ -170,3 +171,21 @@ func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1,
   vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_transposed
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func.func @matmul_transposed(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map5, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}

More information about the Mlir-commits mailing list