[Mlir-commits] [mlir] [mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 (PR #128963)
Daniel Hernandez-Juarez
llvmlistbot at llvm.org
Thu Feb 27 09:05:14 PST 2025
================
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+ if (chipset.majorVersion == 11) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ }
+ if (chipset.majorVersion >= 12) {
+ if (isa<Float8E4M3FNType>(elemSourceType) &&
+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+ if (isa<Float8E4M3FNType>(elemSourceType) &&
+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+ if (isa<Float8E5M2Type>(elemSourceType) &&
+ isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+ if (isa<Float8E5M2Type>(elemSourceType) &&
+ isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+ bool isWave64 = destVectorType.getNumElements() == 4;
+ // This is the ambiguous case. 8 inputs to the wave64 version means that
+ // we want the 16x16x32 version, but for wave32 they mean the short form.
+ bool has8Inputs = sourceVectorType.getNumElements() == 8;
+ if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
----------------
dhernandez0 wrote:
nit: this condition can be simplified to `if (isWave64 == has8Inputs)`.
https://github.com/llvm/llvm-project/pull/128963
More information about the Mlir-commits mailing list