[Mlir-commits] [mlir] 985f7ff - [mlir][gpu] Add support for integer types in gpu.subgroup_mma ops

Quinn Dawkins llvmlistbot at llvm.org
Tue Feb 7 15:05:36 PST 2023


Author: Quinn Dawkins
Date: 2023-02-07T17:58:01-05:00
New Revision: 985f7ff6326e91b0d508aa1b405f4f26ed683ca6

URL: https://github.com/llvm/llvm-project/commit/985f7ff6326e91b0d508aa1b405f4f26ed683ca6
DIFF: https://github.com/llvm/llvm-project/commit/985f7ff6326e91b0d508aa1b405f4f26ed683ca6.diff

LOG: [mlir][gpu] Add support for integer types in gpu.subgroup_mma ops

Signedness is carried by the `!gpu.mma_matrix` element type to most closely
match the Cooperative Matrix specification, which determines signedness
from the type (and sometimes from the operation).

See: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.html

To handle the lowering from vector to gpu, ops such as `arith.extsi` are
pattern matched between `vector.transfer_read` and `vector.contract` to
determine the signedness of the matrix type. An i8 read whose signedness
cannot be inferred this way is not converted.
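
As a rough illustration of that matching rule, here is a minimal standalone
sketch. It does not use the MLIR API: `OpKind`, `Op`, `Signedness`, and
`inferInt8Signedness` are hypothetical stand-ins invented for this example,
and the sketch folds the separate checks from VectorToGPU.cpp into one
function for brevity.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR operations (not the MLIR API).
enum class OpKind { TransferRead, ExtSI, Contract, Other };

struct Op {
  OpKind kind;
  unsigned elementBits;          // element width of the value this op produces
  std::vector<const Op *> users; // ops that consume this op's result
};

enum class Signedness { Unknown, Signed };

// Decide whether an 8-bit transfer read can become a signed MMA matrix load.
// Modeled on the checks in the patch: the read must have exactly one user,
// that user must be a sign extend, and the extend must feed only contractions.
Signedness inferInt8Signedness(const Op &read) {
  assert(read.kind == OpKind::TransferRead);
  if (read.elementBits != 8)
    return Signedness::Unknown; // this sketch only models the i8 case
  if (read.users.size() != 1 || read.users.front()->kind != OpKind::ExtSI)
    return Signedness::Unknown;
  const Op &ext = *read.users.front();
  for (const Op *user : ext.users)
    if (user->kind != OpKind::Contract)
      return Signedness::Unknown;
  return Signedness::Signed;
}
```

In this model, a read feeding `ExtSI` feeding `Contract` yields `Signed`
(an `si8` matrix type in the patch), while a read consumed directly by a
contraction yields `Unknown` and would be left unconverted.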

Enables s8 and u8 WMMA types in NVVM for the GPUToNVVM conversion.
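
For the 8-bit types, the number of i32 registers in a fragment depends on the
fragment's shape, which is why `inferMMAType` gains `nRow` and `nCol`
parameters in this patch. The standalone sketch below mirrors that arithmetic
without depending on MLIR; `Frag`, `FragShape`, `fragShapeFromMNK`, and
`int8FragRegisters` are hypothetical names used only for illustration, and
returning 0 for unsupported shapes is a choice made for this example (the
patch asserts instead).

```cpp
#include <cassert>

// Hypothetical stand-in for NVVM::MMAFrag, restricted to the A/B/C fragments.
enum class Frag { A, B, C };

struct FragShape {
  int rows;
  int cols;
};

// Map an (m, n, k) WMMA geometry to a fragment shape:
// A is m x k, B is k x n, and the accumulator is m x n.
FragShape fragShapeFromMNK(Frag frag, int m, int n, int k) {
  assert(m > 0 && n > 0 && k > 0);
  if (frag == Frag::A)
    return {m, k};
  if (frag == Frag::B)
    return {k, n};
  return {m, n};
}

// Number of i32 registers holding an s8/u8 A or B fragment.
// The count follows the "parallel" dimension: rows for A, columns for B.
// Returns 0 for shapes or fragments this sketch does not cover.
unsigned int8FragRegisters(Frag frag, int nRow, int nCol) {
  int parallelSize = 0;
  if (frag == Frag::A)
    parallelSize = nRow;
  if (frag == Frag::B)
    parallelSize = nCol;
  switch (parallelSize) {
  case 8:
    return 1;
  case 16:
    return 2;
  case 32:
    return 4;
  default:
    return 0;
  }
}
```

For the m32n8k16 geometry this gives four registers for `A` (32x16) and one
for `B` (16x8), which agrees with the struct types checked in the
`gpu_wmma_mma_int8_op` test; the s32 accumulator always uses eight registers
in the patch.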

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143223

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
    mlir/test/Dialect/GPU/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index e680a2a3c56fa..a04d8d97afa5c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -101,7 +101,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 498117eb43ca6..32ab246c74f05 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1150,6 +1150,10 @@ def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     matrix which eventually allows the lowering to determine the size of each
     row.  If the `transpose` attribute is present then the op does a transposed load.
 
+    For integer types, the resulting `!gpu.mma_matrix` type needs to specify the
+    signedness of the data if the matrix type is an `A` or `B` operand for
+    `gpu.subgroup_mma_compute`.
+
     This op is often meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_compute`.
 
@@ -1201,7 +1205,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1227,11 +1231,15 @@ def GPU_SubgroupMmaComputeOp
     as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
     the operation held by all threads in a subgroup. `a_transpose` or
     `b_transpose` if present, signify that the respective operand was loaded in a
-    transposed manner. The transpose opernads are required to map to correct
+    transposed manner. The transpose operands are required to map to correct
     underlying intrisics but they currently do not seem to affect correctness
     even if they are absent given that the operands were loaded correctly using
     the `transpose` attribute in `gpu.subgroup_mma_load_matrix` op.
 
+    For integer types, the `A` and `B` matrices carry their signedness with their
+    types. The accumulator type is expected to be signless and imply a signed integer
+    with a greater width than the other two operands.
+
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
     `gpu.subgroup_mma_load_matrix` ops.
 
@@ -1244,9 +1252,9 @@ def GPU_SubgroupMmaComputeOp
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
-                  Arg<MMAMatrixOf<[F16, F32]>>:$opB,
-                  Arg<MMAMatrixOf<[F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
+                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
+                  Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
                   OptionalAttr<UnitAttr>:$a_transpose,
                   OptionalAttr<UnitAttr>:$b_transpose);
 
@@ -1288,7 +1296,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
+  let arguments = (ins AnyTypeOf<[SI8, UI8, I32, F16, F32]>:$value);
 
   let results = (outs GPU_MMAMatrix:$res);
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 81c6d4af51f30..7d82e82eba225 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -37,7 +37,8 @@ enum NVVMMemorySpace {
 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
 /// WMMA_REGS structure.
 std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
-                                             mlir::NVVM::MMAFrag frag,
+                                             mlir::NVVM::MMAFrag frag, int nRow,
+                                             int nCol,
                                              mlir::MLIRContext *context);
 } // namespace NVVM
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 912cd585e3c13..32f826d76be51 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -385,16 +385,20 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
             ["f16"], [], ["f16", "f32"], []>.ret;
+  list<list<WMMA_REGS>> i8_wmma_ops = MMA_OPS<
+            [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
+            ["s8","u8"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
             tf32_wmma_ops,
-            fp_wmma_ops);
+            fp_wmma_ops,
+            i8_wmma_ops);
 
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-            ["a", "b"], ["f16"]>.ret;
+            ["a", "b"], ["f16","s8","u8"]>.ret;
   list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 16>, GEOM<32, 8, 16>, GEOM<8, 32, 16>],
-            ["c", "d"], ["f16", "f32"]>.ret;
+            ["c", "d"], ["f16", "f32","s32"]>.ret;
   list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
             [GEOM<16, 16, 8>],
             ["a", "b"], ["tf32"]>.ret;

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 11e23815481c5..46d40a724c2f6 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -57,6 +57,12 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
                                            : NVVM::MMATypes::tf32;
+
+  if (type.getElementType().isSignedInteger(8))
+    return NVVM::MMATypes::s8;
+  // Accumulator type is signless and implies signed.
+  if (type.getElementType().isInteger(32))
+    return NVVM::MMATypes::s32;
   llvm_unreachable("Unsupported type");
 }
 
@@ -106,8 +112,11 @@ struct WmmaLoadOpToNVVMLowering
     }
     NVVM::MMAFrag frag = convertOperand(retType.getOperand());
     // Check that there is an exisiting instruction for the combination we need.
-    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
+    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
+      llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
+                   << "\n";
       return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+    }
 
     Type resType = convertMMAToLLVMType(retType);
     Location loc = op->getLoc();
@@ -366,8 +375,10 @@ struct WmmaElementwiseOpToNVVMLowering
 LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
+  auto nRow = type.getShape()[0];
+  auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
-      NVVM::inferMMAType(eltType, frag, type.getContext());
+      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 0a6cdae2ebffc..cdd8cd77aa9c0 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -140,6 +140,12 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+
+  // Only allow integer types if the signedness can be inferred.
+  if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
+    if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
+      return false;
+
   AffineMap map = readOp.getPermutationMap();
   OpBuilder b(readOp.getContext());
   AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
@@ -185,8 +191,16 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2 &&
-         broadcastOp.getSource().getType().isa<FloatType>();
+  return broadcastOp.getVectorType().getRank() == 2;
+}
+
+/// Return true if this signed extend op can be folded into a contract op.
+static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
+    return false;
+  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
+    return isa<vector::ContractionOp>(user);
+  });
 }
 
 /// Return the MMA elementwise enum associated with `op` if it is supported.
@@ -268,6 +282,8 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
+  if (auto extend = dyn_cast<arith::ExtSIOp>(op))
+    return signedExtendSupportsMMAMatrixType(extend);
   return elementwiseSupportsMMAMatrixType(op);
 }
 
@@ -411,8 +427,18 @@ struct CombineTransferReadOpTranspose final
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    auto transferReadOp =
-        op.getVector().getDefiningOp<vector::TransferReadOp>();
+    // Look through integer extend ops.
+    Value source = op.getVector();
+    auto extOp = source.getDefiningOp<arith::ExtSIOp>();
+    auto resultType = op.getVectorType();
+    if (extOp) {
+      source = extOp.getOperand();
+      resultType =
+          VectorType::get(resultType.getShape(),
+                          source.getType().cast<VectorType>().getElementType());
+    }
+
+    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
     if (!transferReadOp)
       return failure();
 
@@ -431,11 +457,23 @@ struct CombineTransferReadOpTranspose final
         AffineMap::getPermutationMap(permU, op.getContext());
     AffineMap newMap =
         permutationMap.compose(transferReadOp.getPermutationMap());
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        op, op.getType(), transferReadOp.getSource(),
-        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
-        transferReadOp.getPadding(), transferReadOp.getMask(),
-        transferReadOp.getInBoundsAttr());
+
+    auto loc = op.getLoc();
+    Value result =
+        rewriter
+            .create<vector::TransferReadOp>(
+                loc, resultType, transferReadOp.getSource(),
+                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+                transferReadOp.getPadding(), transferReadOp.getMask(),
+                transferReadOp.getInBoundsAttr())
+            .getResult();
+
+    // Fuse through the integer extend op.
+    if (extOp)
+      result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+                   .getResult();
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -479,14 +517,26 @@ static void convertTransferReadOp(vector::TransferReadOp op,
     stride = 0;
   }
   assert(stride);
+  Value mappingResult = op.getResult();
+  auto elType = op.getVectorType().getElementType();
   const char *fragType = inferFragType(op);
+  if (op->hasOneUse()) {
+    auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
+    // Infer the signedness of the mma type from the signed extend.
+    if (extOp) {
+      elType = IntegerType::get(op.getContext(),
+                                elType.cast<IntegerType>().getWidth(),
+                                IntegerType::Signed);
+      mappingResult = extOp.getResult();
+      fragType = inferFragType(extOp);
+    }
+  }
   gpu::MMAMatrixType type =
-      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
-                              op.getVectorType().getElementType(), fragType);
+      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
       op.getLoc(), type, op.getSource(), op.getIndices(),
       b.getIndexAttr(*stride), isTranspose ? b.getUnitAttr() : UnitAttr());
-  valueMapping[op.getResult()] = load;
+  valueMapping[mappingResult] = load;
 }
 
 static void convertTransferWriteOp(vector::TransferWriteOp op,

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 0ae8e16111148..a64e5e4d74707 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -78,7 +78,9 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32();
+  return elementType.isF16() || elementType.isF32() ||
+         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
+         elementType.isInteger(32);
 }
 
 LogicalResult
@@ -93,7 +95,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "MMAMatrixType must have exactly two dimensions";
 
   if (!MMAMatrixType::isValidElementType(elementType))
-    return emitError() << "MMAMatrixType elements must be F16 or F32";
+    return emitError()
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
 
   return success();
 }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 60d619732e9bc..7ff949e1fbfe2 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -537,7 +537,8 @@ LogicalResult ShflOp::verify() {
 }
 
 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
-                                                   NVVM::MMAFrag frag,
+                                                   NVVM::MMAFrag frag, int nRow,
+                                                   int nCol,
                                                    MLIRContext *context) {
   unsigned numberElements = 0;
   Type elementType;
@@ -555,11 +556,48 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
   } else if (type == NVVM::MMATypes::tf32) {
     elementType = builder.getI32Type();
     numberElements = 4;
+  } else if (type == NVVM::MMATypes::s8 || type == NVVM::MMATypes::u8) {
+    elementType = builder.getI32Type();
+    int parallelSize = 0;
+    if (frag == NVVM::MMAFrag::a)
+      parallelSize = nRow;
+    if (frag == NVVM::MMAFrag::b)
+      parallelSize = nCol;
+
+    // m == 16 && n == 16 && k == 16
+    if (parallelSize == 16)
+      numberElements = 2;
+    // m == 8 && n == 32 && k == 16 or m == 32 && n == 8 && k == 16
+    else if (parallelSize == 8)
+      numberElements = 1;
+    else if (parallelSize == 32)
+      numberElements = 4;
+  } else if (type == NVVM::MMATypes::s32) {
+    elementType = builder.getI32Type();
+    numberElements = 8;
   }
   assert(numberElements != 0 && elementType != nullptr);
   return std::make_pair(elementType, numberElements);
 }
 
+static std::pair<mlir::Type, unsigned>
+inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
+                    int k, MLIRContext *context) {
+  int nRow, nCol;
+  if (frag == NVVM::MMAFrag::a) {
+    nRow = m;
+    nCol = k;
+  } else if (frag == NVVM::MMAFrag::b) {
+    nRow = k;
+    nCol = n;
+  } else {
+    nRow = m;
+    nCol = n;
+  }
+  assert(nRow && nCol);
+  return inferMMAType(type, frag, nRow, nCol, context);
+}
+
 LogicalResult NVVM::WMMALoadOp::verify() {
   unsigned addressSpace =
       getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
@@ -570,8 +608,8 @@ LogicalResult NVVM::WMMALoadOp::verify() {
   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                        getEltype(), getFrag()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo =
-      inferMMAType(getEltype(), getFrag(), getContext());
+  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
+      getEltype(), getFrag(), getM(), getN(), getK(), getContext());
   Type dstType = LLVM::LLVMStructType::getLiteral(
       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
   if (getType() != dstType)
@@ -590,8 +628,8 @@ LogicalResult NVVM::WMMAStoreOp::verify() {
   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
                                         getEltype()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfo =
-      inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
+  std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
+      getEltype(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
   if (getArgs().size() != typeInfo.second)
     return emitOpError() << "expected " << typeInfo.second << " data operands";
   if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
@@ -606,12 +644,12 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
                                       getLayoutB(), getEltypeA(),
                                       getEltypeB()) == 0)
     return emitOpError() << "invalid attribute combination";
-  std::pair<Type, unsigned> typeInfoA =
-      inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
-  std::pair<Type, unsigned> typeInfoB =
-      inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
-  std::pair<Type, unsigned> typeInfoC =
-      inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
+  std::pair<Type, unsigned> typeInfoA = inferMMATypeFromMNK(
+      getEltypeA(), NVVM::MMAFrag::a, getM(), getN(), getK(), getContext());
+  std::pair<Type, unsigned> typeInfoB = inferMMATypeFromMNK(
+      getEltypeA(), NVVM::MMAFrag::b, getM(), getN(), getK(), getContext());
+  std::pair<Type, unsigned> typeInfoC = inferMMATypeFromMNK(
+      getEltypeB(), NVVM::MMAFrag::c, getM(), getN(), getK(), getContext());
   SmallVector<Type, 32> arguments;
   arguments.append(typeInfoA.second, typeInfoA.first);
   arguments.append(typeInfoB.second, typeInfoB.first);

diff  --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index c2d7ec555942f..0b4456a818d58 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -40,6 +40,45 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_int8_load_op() ->
+  // CHECK-SAME: !llvm.struct<(i32, i32)>
+  // CHECK32-LABEL: func @gpu_wmma_int8_load_op() ->
+  func.func @gpu_wmma_int8_load_op() -> (!gpu.mma_matrix<16x16xsi8, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xi8, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp">
+    // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64
+    // CHECK:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i64
+    // CHECK:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i64
+    // CHECK:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i64) -> !llvm.ptr<i8, 3>
+    // CHECK:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
+    // CHECK:  llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
+
+    // CHECK32:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK32: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK32:  %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 3>, ptr<i8, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK32:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]]  : i32
+    // CHECK32:  %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]]  : i32
+    // CHECK32:  %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<i8, 3>, i32) -> !llvm.ptr<i8, 3>
+    // CHECK32:  %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK32:  %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
+    // CHECK32-SAME: {eltype = #nvvm.mma_type<s8>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32}  : (!llvm.ptr<i8, 3>) -> !llvm.struct<(i32, i32)>
+    // CHECK32:  llvm.return %[[FRAG]] : !llvm.struct<(i32, i32)>
+    return %0 : !gpu.mma_matrix<16x16xsi8, "AOp">
+  }
+}
+
+// -----
+
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
@@ -124,6 +163,35 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_mma_int8_op
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(i32, i32, i32, i32)>, %[[B:.*]]: !llvm.struct<(i32)>, %[[C:.*]]: !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>)
+  func.func @gpu_wmma_mma_int8_op(%A : !gpu.mma_matrix<32x16xsi8, "AOp">, %B : !gpu.mma_matrix<16x8xsi8, "BOp">, %C : !gpu.mma_matrix<32x8xi32, "COp">) -> (!gpu.mma_matrix<32x8xi32, "COp">) {
+    %D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<32x16xsi8, "AOp">, !gpu.mma_matrix<16x8xsi8, "BOp"> -> !gpu.mma_matrix<32x8xi32, "COp">
+    // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK:  %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK:  %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(i32)>
+    // CHECK:  %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C5:.*]] = llvm.extractvalue %[[C]][4] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C6:.*]] = llvm.extractvalue %[[C]][5] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C7:.*]] = llvm.extractvalue %[[C]][6] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[C8:.*]] = llvm.extractvalue %[[C]][7] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[B1]], %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]], %[[C7]], %[[C8]]
+    // CHECK-SAME: {eltypeA = #nvvm.mma_type<s8>, eltypeB = #nvvm.mma_type<s32>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 32 : i32, n = 8 : i32} : (
+    // CHECK-SAME: i32, {{.*}}) -> !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK:  llvm.return %[[RES]] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)>
+    return %D : !gpu.mma_matrix<32x8xi32, "COp">
+  }
+}
+
+// -----
+
 gpu.module @test_module {
 
 // CHECK-LABEL: func @gpu_wmma_mma_loop_op

diff  --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index b00d34f23832c..93cfd765d7327 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -225,3 +225,44 @@ func.func @matmul_transposed_broadcasted_2d(%arg0: memref<32x32xf16>, %arg1: mem
   vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// Do not convert to subgroup_mma ops with integer types if signedness cannot be inferred.
+// CHECK-LABEL: func @matmul_no_extend_int8
+//   CHECK-DAG:   %[[A:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+//   CHECK-DAG:   %[[B:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+//   CHECK-DAG:   %[[C:.+]] = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+//       CHECK:   %[[D:.+]] = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+//       CHECK:   vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+func.func @matmul_no_extend_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xi8>, vector<16x16xi8> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}
+
+// CHECK-LABEL: func @matmul_int8
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xsi8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
+func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+  %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+  %c0 = arith.constant 0 : index
+  %cst_i8 = arith.constant 0 : i8
+  %cst_i32 = arith.constant 0 : i32
+  %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+  %Ae = arith.extsi %Ar : vector<16x16xi8> to vector<16x16xi32>
+  %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+  return
+}

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index a139f4c3d8546..4a52455ad0b33 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -485,8 +485,8 @@ func.func @mmamatrix_operand_type(){
 func.func @mmamatrix_invalid_element_type(){
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index
-    // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
+    // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}}
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp">
     return
 }
 
@@ -505,7 +505,7 @@ func.func @mmaLoadOp_identity_layout(){
 // -----
 
 func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
-    // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+    // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}}
     %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
     return
 }


        

