[Mlir-commits] [mlir] 5205c71 - [mlir][gpu] Add support for unsigned integer extend in vector to gpu.subgroup_mma lowering
Quinn Dawkins
llvmlistbot at llvm.org
Tue Feb 14 10:11:23 PST 2023
Author: Quinn Dawkins
Date: 2023-02-14T13:09:46-05:00
New Revision: 5205c7126b2fea73b283e54c4e1bf409089a8d52
URL: https://github.com/llvm/llvm-project/commit/5205c7126b2fea73b283e54c4e1bf409089a8d52
DIFF: https://github.com/llvm/llvm-project/commit/5205c7126b2fea73b283e54c4e1bf409089a8d52.diff
LOG: [mlir][gpu] Add support for unsigned integer extend in vector to gpu.subgroup_mma lowering
Unsigned integer types are now supported in subgroup MMA ops by also
matching against arith.extui ops. This can produce gpu.subgroup_mma_compute
ops with mixed operand signedness, which later conversions must handle:
the SPIR-V cooperative matrix lowering supports this, while the WMMA
lowering does not and now reports a match failure when the A and B
element types differ.
Differential Revision: https://reviews.llvm.org/D143922
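
For illustration, here is a condensed before/after of the new vector-to-mma
test added below (the IR is abridged from that test; the %Am/%Bm/%Cm/%Dm
names in the lowered form are illustrative):

  // Before: an i8 matmul whose LHS is zero-extended and whose RHS is
  // sign-extended before the vector.contract.
  %Ae = arith.extui %Ar : vector<16x16xi8> to vector<16x16xi32>
  %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
  %D = vector.contract {indexing_maps = [#map1, #map2, #map3],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
         %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>

  // After the VectorToGPU conversion: the extend ops fold into the operand
  // signedness, yielding a mixed-signedness compute op with a signless i32
  // accumulator.
  %Am = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : index}
          : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp">
  %Bm = gpu.subgroup_mma_load_matrix %arg1[%c0, %c0] {leadDimension = 16 : index}
          : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
  %Cm = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 16 : index}
          : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
  %Dm = gpu.subgroup_mma_compute %Am, %Bm, %Cm
          : !gpu.mma_matrix<16x16xui8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp">
          -> !gpu.mma_matrix<16x16xi32, "COp">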
Added:
Modified:
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 46d40a724c2f6..bf5be54f593e9 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -60,6 +60,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isSignedInteger(8))
return NVVM::MMATypes::s8;
+ if (type.getElementType().isUnsignedInteger(8))
+ return NVVM::MMATypes::u8;
// Accumulator type is signless and implies signed.
if (type.getElementType().isInteger(32))
return NVVM::MMATypes::s32;
@@ -112,11 +114,8 @@ struct WmmaLoadOpToNVVMLowering
}
NVVM::MMAFrag frag = convertOperand(retType.getOperand());
// Check that there is an exisiting instruction for the combination we need.
- if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0) {
- llvm::errs() << "No matching intrinsic " << m << " " << n << " " << k
- << "\n";
+ if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
- }
Type resType = convertMMAToLLVMType(retType);
Location loc = op->getLoc();
@@ -245,6 +244,12 @@ struct WmmaMmaOpToNVVMLowering
destType) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+ NVVM::MMATypes bElementType = getElementType(
+ subgroupMmaComputeOp.getOpB().getType().cast<gpu::MMAMatrixType>());
+ if (bElementType != sourceType)
+ return rewriter.notifyMatchFailure(
+ op, "WMMA compute op input matrix element types must match.");
+
unpackOp(adaptor.getOpA());
unpackOp(adaptor.getOpB());
unpackOp(adaptor.getOpC());
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index cdd8cd77aa9c0..b0fa50d799160 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -143,7 +143,8 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
// Only allow integer types if the signedness can be inferred.
if (!useNvGpu && readOp.getVectorType().getElementType().isInteger(8))
- if (!readOp->hasOneUse() || !isa<arith::ExtSIOp>(*readOp->user_begin()))
+ if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
+ !isa<arith::ExtUIOp>(*readOp->user_begin())))
return false;
AffineMap map = readOp.getPermutationMap();
@@ -194,8 +195,9 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
return broadcastOp.getVectorType().getRank() == 2;
}
-/// Return true if this signed extend op can be folded into a contract op.
-static bool signedExtendSupportsMMAMatrixType(arith::ExtSIOp extOp) {
+/// Return true if this integer extend op can be folded into a contract op.
+template <typename ExtOpTy>
+static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
return false;
return llvm::all_of(extOp->getUsers(), [](Operation *user) {
@@ -282,8 +284,10 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return constantSupportsMMAMatrixType(constant);
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcastSupportsMMAMatrixType(broadcast);
- if (auto extend = dyn_cast<arith::ExtSIOp>(op))
- return signedExtendSupportsMMAMatrixType(extend);
+ if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
+ return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
+ if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
+ return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
return elementwiseSupportsMMAMatrixType(op);
}
@@ -429,10 +433,11 @@ struct CombineTransferReadOpTranspose final
PatternRewriter &rewriter) const override {
// Look through integer extend ops.
Value source = op.getVector();
- auto extOp = source.getDefiningOp<arith::ExtSIOp>();
auto resultType = op.getVectorType();
- if (extOp) {
- source = extOp.getOperand();
+ Operation *extOp;
+ if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
+ (extOp = source.getDefiningOp<arith::ExtUIOp>())) {
+ source = extOp->getOperand(0);
resultType =
VectorType::get(resultType.getShape(),
source.getType().cast<VectorType>().getElementType());
@@ -469,9 +474,14 @@ struct CombineTransferReadOpTranspose final
.getResult();
// Fuse through the integer extend op.
- if (extOp)
- result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
- .getResult();
+ if (extOp) {
+ if (isa<arith::ExtSIOp>(extOp))
+ result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
+ .getResult();
+ else
+ result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
+ .getResult();
+ }
rewriter.replaceOp(op, result);
return success();
@@ -484,15 +494,15 @@ struct CombineTransferReadOpTranspose final
// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at the this level and
// only care about it during lowering to NVVM.
-template <typename OpTy>
-static const char *inferFragType(OpTy op) {
+static const char *inferFragType(Operation *op) {
for (Operation *users : op->getUsers()) {
auto contract = dyn_cast<vector::ContractionOp>(users);
if (!contract)
continue;
- if (contract.getLhs() == op.getResult())
+ assert(op->getNumResults() == 1);
+ if (contract.getLhs() == op->getResult(0))
return "AOp";
- if (contract.getRhs() == op.getResult())
+ if (contract.getRhs() == op->getResult(0))
return "BOp";
}
return "COp";
@@ -521,14 +531,15 @@ static void convertTransferReadOp(vector::TransferReadOp op,
auto elType = op.getVectorType().getElementType();
const char *fragType = inferFragType(op);
if (op->hasOneUse()) {
- auto extOp = dyn_cast<arith::ExtSIOp>(*op->user_begin());
- // Infer the signedness of the mma type from the signed extend.
- if (extOp) {
- elType = IntegerType::get(op.getContext(),
- elType.cast<IntegerType>().getWidth(),
- IntegerType::Signed);
- mappingResult = extOp.getResult();
- fragType = inferFragType(extOp);
+ auto user = *op->user_begin();
+ // Infer the signedness of the mma type from the integer extend.
+ bool isSignedExtend = isa<arith::ExtSIOp>(user);
+ if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+ elType = IntegerType::get(
+ op.getContext(), elType.cast<IntegerType>().getWidth(),
+ isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+ mappingResult = user->getResult(0);
+ fragType = inferFragType(user);
}
}
gpu::MMAMatrixType type =
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d6346774c5d7b..92ab0cbbb5870 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -4028,9 +4028,19 @@ verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
typeR.getScope() != typeB.getScope() ||
typeR.getScope() != typeC.getScope())
return op.emitOpError("matrix scope must match");
- if (typeA.getElementType() != typeB.getElementType() ||
- typeR.getElementType() != typeC.getElementType())
- return op.emitOpError("matrix element type must match");
+ auto elementTypeA = typeA.getElementType();
+ auto elementTypeB = typeB.getElementType();
+ if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
+ if (elementTypeA.cast<IntegerType>().getWidth() !=
+ elementTypeB.cast<IntegerType>().getWidth())
+ return op.emitOpError(
+ "matrix A and B integer element types must be the same bit width");
+ } else if (elementTypeA != elementTypeB) {
+ return op.emitOpError(
+ "matrix A and B non-integer element types must match");
+ }
+ if (typeR.getElementType() != typeC.getElementType())
+ return op.emitOpError("matrix accumulator element type must match");
return success();
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 93cfd765d7327..c742150401d8e 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -266,3 +266,24 @@ func.func @matmul_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2:
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
return
}
+
+// CHECK-LABEL: func @matmul_mixed_signedness_int8
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xui8, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xi32> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xui8, "AOp">, !gpu.mma_matrix<16x16xsi8, "BOp"> -> !gpu.mma_matrix<16x16xi32, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xi32, "COp">, memref<16x16xi32>
+func.func @matmul_mixed_signedness_int8(%arg0: memref<16x16xi8>, %arg1: memref<16x16xi8>, %arg2: memref<16x16xi32>) {
+ %cst_0 = arith.constant dense<0> : vector<16x16xi8>
+ %c0 = arith.constant 0 : index
+ %cst_i8 = arith.constant 0 : i8
+ %cst_i32 = arith.constant 0 : i32
+ %Ar = vector.transfer_read %arg0[%c0, %c0], %cst_i8 {in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+ %Br = vector.transfer_read %arg1[%c0, %c0], %cst_i8 {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xi8>, vector<16x16xi8>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst_i32 {in_bounds = [true, true]} : memref<16x16xi32>, vector<16x16xi32>
+ %Ae = arith.extui %Ar : vector<16x16xi8> to vector<16x16xi32>
+ %Be = arith.extsi %Br : vector<16x16xi8> to vector<16x16xi32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %Ae, %Be, %C : vector<16x16xi32>, vector<16x16xi32> into vector<16x16xi32>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32>
+ return
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 723d7d476427d..de31458b94771 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,13 +136,21 @@ spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup>
// -----
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
- // expected-error @+1 {{matrix element type must match}}
+ // expected-error @+1 {{matrix A and B non-integer element types must match}}
%r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xf32, Subgroup>, !spirv.coopmatrix<16x8xi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
spirv.Return
}
// -----
+spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi32, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
+ // expected-error @+1 {{matrix A and B integer element types must be the same bit width}}
+ %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi32, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
+ spirv.Return
+}
+
+// -----
+
spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
%0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer> as !spirv.coopmatrix<8x16xi32, Subgroup>
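
As a complement to the negative tests above, the relaxed verifier is expected
to accept integer A/B element types that differ only in signedness (a hedged
sketch, not part of this commit's tests; the function name is illustrative):

spirv.func @cooperative_matrix_muladd_mixed_signedness(%a : !spirv.coopmatrix<8x16xui8, Subgroup>, %b : !spirv.coopmatrix<16x8xsi8, Subgroup>, %c : !spirv.coopmatrix<8x8xi32, Subgroup>) "None" {
  // ui8 and si8 are both 8-bit integer types, so the bit-width check passes.
  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.coopmatrix<8x16xui8, Subgroup>, !spirv.coopmatrix<16x8xsi8, Subgroup> -> !spirv.coopmatrix<8x8xi32, Subgroup>
  spirv.Return
}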