[Mlir-commits] [mlir] b44007b - [mlir][gpu] Relax restriction on MMA store op to allow chain of mma ops.

llvmlistbot at llvm.org
Thu May 27 09:14:17 PDT 2021


Author: thomasraoux
Date: 2021-05-27T09:13:51-07:00
New Revision: b44007bec2470db0d9f100c6a9216d8e05cef608

URL: https://github.com/llvm/llvm-project/commit/b44007bec2470db0d9f100c6a9216d8e05cef608
DIFF: https://github.com/llvm/llvm-project/commit/b44007bec2470db0d9f100c6a9216d8e05cef608.diff

LOG: [mlir][gpu] Relax restriction on MMA store op to allow chain of mma ops.

In order to allow large matmul operations using the MMA ops, we need to chain
compute operations. This is not possible unless the "DOp" and "COp" types have a
matching layout, so remove the "DOp" layout and force the accumulator and result
types to match.
Added a test for the case where the MMA value is accumulated.

Differential Revision: https://reviews.llvm.org/D103023
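
For illustration only (not part of the commit message), a minimal sketch of the
chaining this enables: because the accumulator operand and the result of
`gpu.subgroup_mma_compute` now share the "COp" layout, one compute's result can
feed the next and be stored directly. The SSA names (%a0, %b0, %a1, %b1, %buf,
%c0) and the 16x16 shapes are hypothetical.

```mlir
// Hypothetical chain: %a0/%b0/%a1/%b1 are "AOp"/"BOp" fragments loaded earlier,
// %buf is a memref<16x16xf16> holding the accumulator tile, %c0 is constant 0.
%acc0 = gpu.subgroup_mma_load_matrix %buf[%c0, %c0] {leadDimension = 16 : index}
  : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// First multiply-accumulate; the result keeps the "COp" layout.
%acc1 = gpu.subgroup_mma_compute %a0, %b0, %acc0
  : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
  -> !gpu.mma_matrix<16x16xf16, "COp">
// Second multiply-accumulate chained directly on the previous result.
%acc2 = gpu.subgroup_mma_compute %a1, %b1, %acc1
  : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
  -> !gpu.mma_matrix<16x16xf16, "COp">
// The chained result can be stored without any layout conversion.
gpu.subgroup_mma_store_matrix %acc2, %buf[%c0, %c0] {leadDimension = 16 : index}
  : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
```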

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
    mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 7832d48945656..c6d84f2cbcdc2 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -85,9 +85,9 @@ struct MMAMatrixStorageType : public TypeStorage {
   Type elementType;
 
   /// MMA operand that this MMAMatrix holds. The general form of operation this
-  /// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
-  /// field specifies which operand in the given equation is held by this type.
-  /// The valid values are "AOp", "BOp", "COp" and "DOp".
+  /// type supports is given by the equation C += A*B. This field specifies
+  /// which operand in the given equation is held by this type. The valid values
+  /// are "AOp", "BOp" and "COp".
   StringRef operand;
 };
 
@@ -112,13 +112,13 @@ struct MMAMatrixStorageType : public TypeStorage {
 /// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
 /// are:-
 ///
-///   %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
-///   "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
-///                             "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+///   %3 = gpu.subgroup_mma_compute %0, %1, %2 :
+///   !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
+///    -> !gpu.mma_matrix<16x16xf32, "COp">
 ///
 ///
 ///   gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
-///           : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 // TODO: consider moving this to ODS.
 class MMAMatrixType
     : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
@@ -154,9 +154,8 @@ class MMAMatrixType
   Type getElementType() const;
 
   /// The general form of operation this type supports is given by the equation
-  /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
-  /// given equation is held by this type. String returned can be one of"AOp",
-  /// "BOp", "COp" and "DOp".
+  /// C += A*B. This function returns which operand in the given equation is
+  /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
   StringRef getOperand() const;
 };
 

diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index a29a22a9989bc..bc6f3f3169e1c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -966,7 +966,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
 
     ```mlir
     gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} :
-    !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     ```
   }];
 
@@ -982,7 +982,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   let verifier = [{ return ::verify(*this); }];
 }
 
-def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
+   [NoSideEffect, AllTypesMatch<["opC", "res"]>]>{
 
   let summary = "GPU warp synchronous matrix multiply accumulate";
 
@@ -992,7 +993,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
 
     This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`,
      `B` and `C`operands for the mma operation. The operation performed is represented
-    as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
+    as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
     the operation held by the current thread.
 
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
@@ -1002,8 +1003,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
 
     ```mlir
     %D = gpu.subgroup_mma_compute_matrix %A, %B, %C :
-    !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
-    !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">>
+    -> !gpu.mma_matrix<16x16xf16, "COp">
     ```
   }];
 
@@ -1014,7 +1015,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
   let results = (outs GPU_MMAMatrix:$res);
 
   let assemblyFormat = [{
-    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
+    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
   }];
 
   let verifier = [{ return ::verify(*this); }];

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 42e64e5fb3c6a..ea336dc68f0e2 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -135,11 +135,9 @@ struct LowerGpuOpsToNVVMOpsPass
       numElemsPerThreadF16["AOp"] = 8;
       numElemsPerThreadF16["BOp"] = 8;
       numElemsPerThreadF16["COp"] = 4;
-      numElemsPerThreadF16["DOp"] = 4;
       numElemsPerThreadF32["AOp"] = 8;
       numElemsPerThreadF32["BOp"] = 8;
       numElemsPerThreadF32["COp"] = 8;
-      numElemsPerThreadF32["DOp"] = 8;
       Type structToReturn;
       if (type.getElementType().isF16()) {
         // Number of f16's in 32-bit.

diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index c4458aa05f96d..5f5213a290fd7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -29,7 +29,6 @@ struct CommonLLVMAndBuiltInMLIRTypes {
     numHalfsInOpFrags[A] = 8;
     numHalfsInOpFrags[B] = 8;
     numHalfsInOpFrags[C] = 4;
-    numHalfsInOpFrags[D] = 4;
     i32Ty = IntegerType::get(context, 32);
     f16Ty = FloatType::getF16(context);
     f32Ty = FloatType::getF32(context);
@@ -63,7 +62,7 @@ struct CommonLLVMAndBuiltInMLIRTypes {
   SmallVector<unsigned, 4> numHalfsInOpFrags;
   /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
   /// (beta*C).
-  enum OperandMap { A, B, C, D };
+  enum OperandMap { A, B, C };
 };
 
 /// Checks if all the operands of the op being lowered are of LLVM Types. The
@@ -305,7 +304,7 @@ struct WmmaStoreOpToNVVMLowering
             .getType()
             .cast<gpu::MMAMatrixType>()
             .getElementType() == f16Ty) {
-      for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
+      for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
         Value toUse = rewriter.create<LLVM::ExtractValueOp>(
             loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
         storeOpOperands.push_back(toUse);

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 1f081d896bfc2..39acf4182863d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -64,8 +64,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                       ArrayRef<int64_t> shape, Type elementType,
                       StringRef operand) {
   if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp") && !operand.equals("DOp"))
-    return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
+      !operand.equals("COp"))
+    return emitError() << "operand expected to be one of AOp, BOp or COp";
 
   if (shape.size() != 2)
     return emitError() << "MMAMatrixType must have exactly two dimensions";
@@ -1027,9 +1027,9 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
         "destination memorySpace of kGenericMemorySpace, "
         "kGlobalMemorySpace or kSharedMemorySpace only allowed");
 
-  if (!srcMatrixType.getOperand().equals("DOp"))
+  if (!srcMatrixType.getOperand().equals("COp"))
     return op.emitError(
-        "expected the operand matrix being stored to have 'DOp' operand type");
+        "expected the operand matrix being stored to have 'COp' operand type");
 
   return success();
 }

diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index c44d7f0cfa301..de5d0d3fcf1c0 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -31,11 +31,11 @@ gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
   // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
-  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
     %j = constant 16 : index
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
@@ -61,9 +61,9 @@ gpu.module @test_module {
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_mma_op
-  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
-  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
-    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -84,8 +84,70 @@ gpu.module @test_module {
     // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  llvm.return
-    return
+    // CHECK:  %[[RES:.*]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    return %D : !gpu.mma_matrix<16x16xf16, "COp">
   }
 }
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_mma_loop_op
+//       CHECK:   %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:  ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
+//       CHECK:   llvm.cond_br %38, ^bb2, ^bb3
+//       CHECK:  ^bb2:  // pred: ^bb1
+//       CHECK:   %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A3:.+]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A4:.+]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A5:.+]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A6:.+]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A7:.+]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B0:.+]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B1:.+]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B2:.+]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B3:.+]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B4:.+]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B5:.+]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B6:.+]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B7:.+]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:  ^bb3:  // pred: ^bb1
+//       CHECK:   %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+
+  func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
+      %c0 = constant 0 : index
+      %c128 = constant 128 : index
+      %c32 = constant 32 : index
+      %0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+      br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
+    ^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">):  // 2 preds: ^bb0, ^bb2
+      %3 = cmpi slt, %1, %c128 : index
+      cond_br %3, ^bb2, ^bb3
+    ^bb2:  // pred: ^bb1
+      %4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+      %5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+      %6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+      %7 = addi %1, %c32 : index
+      br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
+    ^bb3:  // pred: ^bb1
+      gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
+      return
+    }
+}

diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 58eca3b875855..f399ddd8ba76b 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -474,7 +474,7 @@ func @mmamatrix_invalid_shape(){
 func @mmamatrix_operand_type(){
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
-    // expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
+    // expected-error @+1 {{operand expected to be one of AOp, BOp or COp}}
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
     return
 }
@@ -513,35 +513,25 @@ func @mmaLoadOp_invalid_mem_space(){
 
 // -----
 
-func @mmaLoadOp_operand_type(){
-    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
-    %i = constant 16 : index
-    // expected-error @+1 {{only AOp, BOp and COp can be loaded}}
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
-    return
-}
-
-// -----
-
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
-func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
     %i = constant 16 : index
     %j = constant 16 : index
     // expected-error @+1 {{expected identity layout map for destination memref}}
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
     return
 }
 
 // -----
 
-func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
     %i = constant 16 : index
     %j = constant 16 : index
     // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
     return
 }
 
@@ -551,7 +541,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
     %j = constant 16 : index
-    // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
+    // expected-error @+1 {{expected the operand matrix being stored to have 'COp' operand type}}
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
     return
 }
@@ -560,7 +550,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
 
 func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     // expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
-    %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }
 
@@ -568,6 +558,6 @@ func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B
 
 func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     // expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
-    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
index 9e6802a3dac07..3dc911b32b69b 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
@@ -82,9 +82,9 @@ module attributes {gpu.container_module}  {
       %1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
       %2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 
-      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 
-      gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "DOp">, memref<16x16xf16>
+      gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
 
       gpu.return
     }

diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
index 5ca0147889b34..ba948ea8997b1 100644
--- a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir
@@ -73,9 +73,9 @@ module attributes {gpu.container_module}  {
       %1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
       %2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
 
-      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
 
-      gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+      gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 
       gpu.return
     }


        

