[Mlir-commits] [mlir] 3d35546 - Support `transpose` mode for `gpu.subgroup` WMMA ops

Uday Bondhugula llvmlistbot at llvm.org
Mon Dec 5 09:07:19 PST 2022


Author: Navdeep Katel
Date: 2022-12-05T22:37:02+05:30
New Revision: 3d35546cd1680b0e087fb6c9976799760146c377

URL: https://github.com/llvm/llvm-project/commit/3d35546cd1680b0e087fb6c9976799760146c377
DIFF: https://github.com/llvm/llvm-project/commit/3d35546cd1680b0e087fb6c9976799760146c377.diff

LOG: Support `transpose` mode for `gpu.subgroup` WMMA ops

Add support for loading, computing, and storing `gpu.subgroup` WMMA ops
in transpose mode as well. Update the GPU to NVVM lowerings to support
`transpose` mode and update integration tests as well.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D139021

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
    mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 44684da5baccc..0642b1865b5fd 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1141,7 +1141,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     `!gpu.mma_matrix` is the source value containing the data to be stored into the
     destination memref which can be in global or shared memory.  The store address
     is determined using the indices provided. The `leadDimension` attribute
-    specifies the leading dimension of the destination matrix.
+    specifies the leading dimension of the destination matrix. If the
+    `transpose` attribute is present then the op does a transposed store.
 
     This op is often meant to be used along with `gpu.subgroup_mma_load_matrix` and
     `gpu.subgroup_mma_compute`.
@@ -1157,7 +1158,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
-                  IndexAttr:$leadDimension);
+                  IndexAttr:$leadDimension,
+                  OptionalAttr<UnitAttr>:$transpose);
 
   let assemblyFormat = [{
     $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
@@ -1165,8 +1167,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   let hasVerifier = 1;
 }
 
-def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
-   [Pure, AllTypesMatch<["opC", "res"]>]>{
+def GPU_SubgroupMmaComputeOp
+    : GPU_Op<"subgroup_mma_compute", [Pure, AllTypesMatch<["opC", "res"]>]> {
 
   let summary = "GPU warp synchronous matrix multiply accumulate";
 
@@ -1175,9 +1177,14 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
     operation using all the threads in a subgroup.
 
     This operation takes three `!gpu.mma_matrix`s as arguments: these hold `A`,
-     `B` and `C`operands for the mma operation. The operation performed is represented
+    `B` and `C` operands for the mma operation. The operation performed is represented
     as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
-    the operation held by all threads in a subgroup.
+    the operation held by all threads in a subgroup. `a_transpose` or
+    `b_transpose`, if present, signify that the respective operand was loaded in a
+    transposed manner. The transpose attributes are required to map to the correct
+    underlying intrinsics, but they currently do not seem to affect correctness
+    even if they are absent, given that the operands were loaded correctly using
+    the `transpose` attribute in the `gpu.subgroup_mma_load_matrix` op.
 
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_load_matrix` ops.
@@ -1193,9 +1200,11 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
 
   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
                   Arg<MMAMatrixOf<[F16, F32]>>:$opB,
-                  Arg<MMAMatrixOf<[F16, F32]>>:$opC);
+                  Arg<MMAMatrixOf<[F16, F32]>>:$opC,
+                  OptionalAttr<UnitAttr>:$a_transpose,
+                  OptionalAttr<UnitAttr>:$b_transpose);
 
-  let results = (outs GPU_MMAMatrix:$res);
+  let results = (outs GPU_MMAMatrix : $res);
 
   let assemblyFormat = [{
     $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
@@ -1215,11 +1224,11 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
     The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with
     constant elements.
 
-    The operation takes a scalar input and return a `!gpu.mma_matrix` where each
-    element of is equal to the operand constant. The destination mma_matrix type
-    must have elememt type equal to the constant type. Since the layout of
-    `!gpu.mma_matrix` is opaque this only support setting all the elements to
-    the same value.
+    The operation takes a scalar input and returns a `!gpu.mma_matrix` where
+    each element is equal to the operand constant. The destination
+    mma_matrix type must have element type equal to the constant type. Since
+    the layout of `!gpu.mma_matrix` is opaque, this only supports setting all the
+    elements to the same value.
 
     This op is meant to be used along with `gpu.subgroup_mma_compute`.
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 47687d9df4536..11e23815481c5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,12 +77,11 @@ struct WmmaLoadOpToNVVMLowering
     if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
       return failure();
 
-    // TODO: Support transposed mma loads.
-    if (subgroupMmaLoadMatrixOp.getTranspose())
-      return failure();
-
     // Get the shape of the MMAMatrix type being returned. The shape will
     // choose which intrinsic this op will be lowered to.
+    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
+                                 ? NVVM::MMALayout::col
+                                 : NVVM::MMALayout::row;
     gpu::MMAMatrixType retType =
         subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> retTypeShape = retType.getShape();
@@ -105,7 +104,6 @@ struct WmmaLoadOpToNVVMLowering
       n = retTypeShape[1];
       k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
     }
-    NVVM::MMALayout layout = NVVM::MMALayout::row;
     NVVM::MMAFrag frag = convertOperand(retType.getOperand());
     // Check that there is an exisiting instruction for the combination we need.
     if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
@@ -154,7 +152,9 @@ struct WmmaStoreOpToNVVMLowering
     gpu::MMAMatrixType srcType =
         subgroupMmaStoreMatrixOp.getSrc().getType().cast<gpu::MMAMatrixType>();
     ArrayRef<int64_t> srcTypeShape = srcType.getShape();
-    NVVM::MMALayout layout = NVVM::MMALayout::row;
+    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
+                                 ? NVVM::MMALayout::col
+                                 : NVVM::MMALayout::row;
     NVVM::MMATypes eltype = getElementType(srcType);
     int64_t m = srcTypeShape[0];
     int64_t n = srcTypeShape[1];
@@ -224,10 +224,15 @@ struct WmmaMmaOpToNVVMLowering
     int64_t m = cTypeShape[0];
     int64_t n = cTypeShape[1];
     int64_t k = aTypeShape[1];
-    NVVM::MMALayout layout = NVVM::MMALayout::row;
+    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
+                                  ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row;
+    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
+                                  ? NVVM::MMALayout::col
+                                  : NVVM::MMALayout::row;
     NVVM::MMATypes sourceType = getElementType(aType);
     NVVM::MMATypes destType = getElementType(cType);
-    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType,
+    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                         destType) == 0)
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
 
@@ -236,7 +241,7 @@ struct WmmaMmaOpToNVVMLowering
     unpackOp(adaptor.getOpC());
 
     rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
-        op, adaptor.getOpC().getType(), m, n, k, layout, layout, sourceType,
+        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
         destType, unpackedOps);
     return success();
   }

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index 5bc7301ff650c..43d7f6237671c 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -87,10 +87,9 @@ struct WmmaLoadOpToSPIRVLowering
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
         loc, i32Type, IntegerAttr::get(i32Type, stride));
-    bool useColMajor =
-        static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
+    bool isColMajor = static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
     auto columnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
     rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
         subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
         spirv::MemoryAccessAttr());
@@ -118,11 +117,13 @@ struct WmmaStoreOpToSPIRVLowering
     auto i32Type = rewriter.getI32Type();
     auto strideValue = rewriter.create<spirv::ConstantOp>(
         loc, i32Type, IntegerAttr::get(i32Type, stride));
-    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
-        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    bool useColMajor =
+        static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
     rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
         subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
-        coloumnMajor, spirv::MemoryAccessAttr());
+        columnMajor, spirv::MemoryAccessAttr());
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index c168765458b5d..937913445f30c 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,9 +473,9 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
   assert(stride);
   OpBuilder b(op);
   Value matrix = valueMapping.find(op.getVector())->second;
-  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
-                                          op.getIndices(),
-                                          b.getIndexAttr(*stride));
+  b.create<gpu::SubgroupMmaStoreMatrixOp>(
+      op.getLoc(), matrix, op.getSource(), op.getIndices(),
+      b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
   op.erase();
 }
 
@@ -800,8 +800,9 @@ static void convertContractOp(vector::ContractionOp op,
   Value opA = valueMapping.find(op.getLhs())->second;
   Value opB = valueMapping.find(op.getRhs())->second;
   Value opC = valueMapping.find(op.getAcc())->second;
-  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
-                                                     opA, opB, opC);
+  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
+      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
+      /*b_transpose=*/UnitAttr());
   valueMapping[op.getResult()] = matmul;
 }
 

diff  --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index 664a4dab3f91c..c2d7ec555942f 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -10,7 +10,7 @@ gpu.module @test_module {
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
@@ -20,7 +20,7 @@ gpu.module @test_module {
     // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
-    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 
     // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
@@ -32,7 +32,7 @@ gpu.module @test_module {
     // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
-    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
   }
@@ -50,7 +50,7 @@ gpu.module @test_module {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
     %j = arith.constant 16 : index
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK:  %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -64,7 +64,7 @@ gpu.module @test_module {
     // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
     // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK:  llvm.return
 
     // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
@@ -80,7 +80,7 @@ gpu.module @test_module {
     // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
     // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
     // CHECK32:  nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
-    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
+    // CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
     // CHECK32:  llvm.return
     return
   }
@@ -93,7 +93,7 @@ gpu.module @test_module {
   // CHECK-LABEL: func @gpu_wmma_mma_op
   // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
   func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
-    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -115,7 +115,7 @@ gpu.module @test_module {
     // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]]
-    // CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
+    // CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
     // CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     return %D : !gpu.mma_matrix<16x16xf16, "COp">

diff  --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
index 0c4b0563b0b19..c4dc7458bc312 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
@@ -12,7 +12,8 @@ module attributes {
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
-      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
       %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
       // CHECK: spirv.Return
       gpu.return
@@ -22,6 +23,29 @@ module attributes {
 
 // -----
 
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_load_op_transpose
+    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    gpu.func @gpu_wmma_load_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, %[[COLMAJOR]] :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
 module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
@@ -35,7 +59,8 @@ module attributes {
       attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
       %i = arith.constant 16 : index
       %j = arith.constant 16 : index
-      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant false
+      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
       gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
@@ -45,6 +70,30 @@ module attributes {
 
 // -----
 
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_store_op_transpose
+    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
+    gpu.func @gpu_wmma_store_op_transpose(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK: %[[COLMAJOR:.*]] = spirv.Constant true
+      // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, %[[COLMAJOR]] : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
 module attributes {
   gpu.container_module,
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
@@ -107,4 +156,4 @@ module attributes {
       gpu.return
     }
   }
-}
\ No newline at end of file
+}

diff  --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index 00fc729ed158d..8b13a1a362fdd 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -22,10 +22,12 @@ func.func @main() {
   %c32 = arith.constant 32 : index
   %c1 = arith.constant 1 : index
 
-  // Intialize the Input matrix with ones.
+  // Initialize the input matrix with the column index in each row.
   scf.for %arg0 = %c0 to %c16 step %c1 {
     scf.for %arg1 = %c0 to %c16 step %c1 {
-      memref.store %f1, %0[%arg0, %arg1] : memref<16x16xf16>
+      %2 = arith.index_cast %arg1 : index to i16
+      %3 = arith.sitofp %2 : i16 to f16
+      memref.store %3, %0[%arg0, %arg1] : memref<16x16xf16>
     }
   }
   // Intialize the accumulator matrix with zeros.
@@ -43,11 +45,11 @@ func.func @main() {
 
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
-    %A = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %A = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
     %B = gpu.subgroup_mma_load_matrix %0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
     %C = gpu.subgroup_mma_load_matrix %22[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 
-    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    %R = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 
     gpu.subgroup_mma_store_matrix %R, %0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
     gpu.terminator
@@ -64,22 +66,22 @@ func.func @main() {
 
   // Print the memref after computation.
   call @printMemrefF32(%3) : (memref<*xf32>) -> ()
-  // CHECK: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16],
-  // CHECK-NEXT: [16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16,   16]
+  // CHECK:      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+  // CHECK-NEXT: [0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240],
+  // CHECK-NEXT: [0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 480],
+  // CHECK-NEXT: [0, 48, 96, 144, 192, 240, 288, 336, 384, 432, 480, 528, 576, 624, 672, 720],
+  // CHECK-NEXT: [0, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960],
+  // CHECK-NEXT: [0, 80, 160, 240, 320, 400, 480, 560, 640, 720, 800, 880, 960, 1040, 1120, 1200],
+  // CHECK-NEXT: [0, 96, 192, 288, 384, 480, 576, 672, 768, 864, 960, 1056, 1152, 1248, 1344, 1440],
+  // CHECK-NEXT: [0, 112, 224, 336, 448, 560, 672, 784, 896, 1008, 1120, 1232, 1344, 1456, 1568, 1680],
+  // CHECK-NEXT: [0, 128, 256, 384, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920],
+  // CHECK-NEXT: [0, 144, 288, 432, 576, 720, 864, 1008, 1152, 1296, 1440, 1584, 1728, 1872, 2016, 2160],
+  // CHECK-NEXT: [0, 160, 320, 480, 640, 800, 960, 1120, 1280, 1440, 1600, 1760, 1920, 2080, 2240, 2400],
+  // CHECK-NEXT: [0, 176, 352, 528, 704, 880, 1056, 1232, 1408, 1584, 1760, 1936, 2112, 2288, 2464, 2640],
+  // CHECK-NEXT: [0, 192, 384, 576, 768, 960, 1152, 1344, 1536, 1728, 1920, 2112, 2304, 2496, 2688, 2880],
+  // CHECK-NEXT: [0, 208, 416, 624, 832, 1040, 1248, 1456, 1664, 1872, 2080, 2288, 2496, 2704, 2912, 3120],
+  // CHECK-NEXT: [0, 224, 448, 672, 896, 1120, 1344, 1568, 1792, 2016, 2240, 2464, 2688, 2912, 3136, 3360],
+  // CHECK-NEXT: [0, 240, 480, 720, 960, 1200, 1440, 1680, 1920, 2160, 2400, 2640, 2880, 3120, 3360, 3600]]
   return
 }
 


        


More information about the Mlir-commits mailing list