[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Dec 1 09:34:28 PST 2025
================
@@ -653,23 +653,57 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
+/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of i8 of length 4 or 8, bitcast it to i32 or i64,
+///    respectively
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+ Value input) {
Type inputType = input.getType();
- Type outputType = rewriter.getI32Type();
+
+ // Handle scalar i8: zero extend to i32.
if (auto intType = dyn_cast<IntegerType>(inputType))
- return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
- return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+ // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
+ if (auto vectorType = dyn_cast<VectorType>(inputType)) {
+ int64_t numElements = vectorType.getNumElements();
+ IntegerType outputType =
+ (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+ }
+
+ llvm_unreachable("unexpected input type for scale operand");
+}
+
+/// Maps f8 scale element types to WMMA scale format codes.
+static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
+ return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
+ .Case([](Float8E8M0FNUType) { return 0; })
+ .Case([](Float8E4M3FNType) { return 2; })
----------------
krzysz00 wrote:
... Oh, well, if this has been assigned a code, then it works, cool!
https://github.com/llvm/llvm-project/pull/169854
More information about the Mlir-commits mailing list