[Mlir-commits] [mlir] 4e4af13 - [mlir][gpu][nvvm] fixed bug with literal for inline asm for mma instruction
Aart Bik
llvmlistbot at llvm.org
Fri Mar 17 09:22:31 PDT 2023
Author: Aart Bik
Date: 2023-03-17T09:22:15-07:00
New Revision: 4e4af1338da5bdbf10e113c0462d7eb7222b5d97
URL: https://github.com/llvm/llvm-project/commit/4e4af1338da5bdbf10e113c0462d7eb7222b5d97
DIFF: https://github.com/llvm/llvm-project/commit/4e4af1338da5bdbf10e113c0462d7eb7222b5d97.diff
LOG: [mlir][gpu][nvvm] fixed bug with literal for inline asm for mma instruction
The 'mma.sp.sync.aligned' family of instructions expects
the sparsity selector as a direct literal (0x0 or 0x1).
The current MLIR inline asm passed this as a value in
a register, which broke the downstream assemblers.
This is a small step towards supporting 2:4 sparsity on
NVIDIA GPUs in the sparse compiler of MLIR.
Reviewed By: ThomasRaoux, guraypp
Differential Revision: https://reviews.llvm.org/D146110
Added:
Modified:
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index d27af456f41b9..e88bb05d0b0b9 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -401,22 +401,22 @@ static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
ss << "=r,";
for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
ss << "r,";
- // The final two operands are for the sparsity metadata and sparsity selector.
- ss << "r,r";
+ // The final operand is for the sparsity metadata.
+ // The sparsity selector appears as direct literal.
+ ss << "r";
ss.flush();
return str;
}
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
-/// the give parameters. Note that this function doesn't do any validation,
+/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
-static std::string
-buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
- unsigned matBSize, unsigned matCSize,
- NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
- NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
- std::optional<NVVM::MMAIntOverflow> overflow) {
+static std::string buildMmaSparseAsmString(
+ const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
+ unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+ NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
+ std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
return NVVM::stringifyMMATypes(ptxType);
};
@@ -442,7 +442,8 @@ buildMmaSparseAsmString(const std::array<int64_t, 3> &shape, unsigned matASize,
ss << "},";
}
ss << "$" << asmArgIdx++ << ",";
- ss << "$" << asmArgIdx++ << ";";
+ assert(metaDataSelector <= 1);
+ ss << "0x" << metaDataSelector << ";";
ss.flush();
return asmStr;
}
@@ -459,22 +460,21 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
- std::string asmStr = buildMmaSparseAsmString(
- shape, unpackedAData.size(), unpackedB.size(), unpackedC.size(), ptxTypeA,
- ptxTypeB, ptxTypeC, ptxTypeD, overflow);
- std::string constraintStr = buildMmaSparseAsmConstraintString(
- unpackedAData.size(), unpackedB.size(), unpackedC.size());
+ const unsigned matASize = unpackedAData.size();
+ const unsigned matBSize = unpackedB.size();
+ const unsigned matCSize = unpackedC.size();
- Value selectorVal = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(metadataSelector));
+ std::string asmStr = buildMmaSparseAsmString(
+ shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
+ ptxTypeD, overflow, metadataSelector);
+ std::string constraintStr =
+ buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);
SmallVector<Value> asmVals;
- asmVals.reserve(unpackedAData.size() + unpackedB.size() + unpackedC.size() +
- 2);
+ asmVals.reserve(matASize + matBSize + matCSize + 1);
for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
- asmVals.push_back(selectorVal);
return rewriter.create<LLVM::InlineAsmOp>(loc,
/*resultTypes=*/intrinsicResultType,
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index f4125847230de..1133e09c538fa 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -333,12 +333,11 @@ func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
// CHECK-NOT llvm.extractvalue
// CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
- // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,$13;"
- // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r,r"
- // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3,$4,$5},{$6,$7,$8,$9},{$10,$11},$12,0x0;"
+ // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 32]} :
@@ -372,12 +371,11 @@ func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>,
// CHECK-NOT llvm.extractvalue
// CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
- // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,$9;"
- // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r,r"
- // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x0;"
+ // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16]} :
@@ -413,12 +411,11 @@ func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>,
// CHECK-NOT llvm.extractvalue
// CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32
- // CHECK: %[[sparseSelector:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att
- // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32
- // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r,r"
- // CHECK-SAME: %[[sparseMetadata]], %[[sparseSelector]] :
+ // CHECK-SAME: "mma.sp.sync.aligned.m16n8k64.row.col.satfinite.s32.s8.s8.s32 {$0,$1,$2,$3},{$4,$5,$6,$7},{$8,$9,$10,$11},{$12,$13,$14,$15},$16,0x0;"
+ // CHECK-SAME: "=r,=r,=r,=r,r,r,r,r,r,r,r,r,r,r,r,r,r"
+ // CHECK-SAME: %[[sparseMetadata]] :
// CHECK-SAME: -> !llvm.struct<(i32, i32, i32, i32)
%d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 64]} :
More information about the Mlir-commits
mailing list