[Mlir-commits] [mlir] [mlir][AMDGPU] Add int4 intrinsics, mixed-type fp8 to handle gfx12 (PR #128963)
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Feb 27 09:54:38 PST 2025
================
@@ -427,25 +434,33 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
// fp8/int8 information is lost during the conversion process.
auto mlirInputType = cast<VectorType>(mlirInput.getType());
- bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
- if (isInputInt8) {
+ bool isInputInteger = mlirInputType.getElementType().isInteger();
+ if (isInputInteger) {
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
bool localIsUnsigned = isUnsigned;
- if (elemType.isUnsignedInteger(8)) {
+ if (elemType.isUnsignedInteger()) {
localIsUnsigned = true;
- } else if (elemType.isSignedInteger(8)) {
+ } else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
operands.push_back(sign);
}
- int64_t numBytes = vectorType.getNumElements();
+ int64_t numBits =
+ vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
Type i32 = rewriter.getI32Type();
- VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
- auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+ Type intrinsicInType = numBits <= 32
+ ? (Type)rewriter.getIntegerType(numBits)
+ : (Type)VectorType::get(numBits / 32, i32);
+ auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
Value result = rewriter.createOrFold<LLVM::BitcastOp>(
----------------
krzysz00 wrote:
It's the `result` of the function, but yes, it could be named `argument`.
https://github.com/llvm/llvm-project/pull/128963
More information about the Mlir-commits
mailing list