[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Nov 30 06:26:26 PST 2025
================
@@ -653,23 +653,33 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
+/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
 /// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32
 ///
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+                              Value input) {
   Type inputType = input.getType();
-  Type outputType = rewriter.getI32Type();
+
+  // Handle scalar i8: zero extend to i32
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
-  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64
----------------
kuhar wrote:
```suggestion
// Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
```
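
For readers following the thread, here is a rough standalone sketch of the value-level effect that the doc comment and the `vector<4xi8> -> i32` / `vector<8xi8> -> i64` comment describe. It is not the code from the PR and uses no MLIR APIs. The names `zextScale` and `packLanes` are hypothetical, and it assumes a little-endian lane layout, where element 0 lands in the lowest byte.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the scalar case: an i8 scale is zero extended
// to i32, so the upper 24 bits are always zero.
constexpr std::uint32_t zextScale(std::uint8_t scale) {
  return static_cast<std::uint32_t>(scale);
}

// Hypothetical stand-in for the vector case: N byte-sized lanes are
// reinterpreted as one N-byte integer. Assumes little-endian layout
// (lane 0 becomes the least significant byte).
template <typename Out, std::size_t N>
constexpr Out packLanes(const std::array<std::uint8_t, N> &lanes) {
  static_assert(sizeof(Out) == N, "output width must match total lane width");
  Out result = 0;
  for (std::size_t i = 0; i < N; ++i)
    result |= static_cast<Out>(lanes[i]) << (8 * i);
  return result;
}

// Compile-time checks of the two vector widths mentioned in the comment.
constexpr std::array<std::uint8_t, 4> kScales4{0x01, 0x02, 0x03, 0x04};
constexpr std::array<std::uint8_t, 8> kScales8{0x01, 0x02, 0x03, 0x04,
                                               0x05, 0x06, 0x07, 0x08};
static_assert(packLanes<std::uint32_t>(kScales4) == 0x04030201u,
              "4 x i8 packs into one i32");
static_assert(packLanes<std::uint64_t>(kScales8) == 0x0807060504030201ull,
              "8 x i8 packs into one i64");
```

The point of the sketch is that the result width follows the total bit width of the input: a scalar is widened to 32 bits, while a vector of scales is reinterpreted as an integer of the same size.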
https://github.com/llvm/llvm-project/pull/169854