[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 30 06:26:26 PST 2025


================
@@ -653,23 +653,33 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
   return input;
 }
 
-/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
-/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
+/// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR
+/// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
 ///
 /// Specifically:
 /// 1. If `input` is a i8 value, zero extend it to i32
-/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
+/// 2. If `input` is a vector of length 4 or 8 and type i8, cast it to i32 or i64 respectively
 ///
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value input) {
+static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
+                              Value input) {
   Type inputType = input.getType();
-  Type outputType = rewriter.getI32Type();
+
+  // Handle scalar i8: zero extend to i32
   if (auto intType = dyn_cast<IntegerType>(inputType))
-    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
-  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
+    return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
+
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64
----------------
kuhar wrote:

```suggestion
  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
```
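To make the two conversion paths concrete, here is a small standalone C++ model of what the emitted LLVM ops do to the bits. It is not MLIR code. The helper names `zextScale`, `packScales4` and `packScales8` are hypothetical, and the model assumes the little-endian layout that an LLVM `bitcast` produces on AMDGPU, where vector element 0 ends up in the least significant byte.

```cpp
#include <array>
#include <cstdint>

// Hypothetical model of castScaleOperand's bit-level effect (not an MLIR API).

// Scalar i8 -> i32: zero extension (llvm.zext), so 0xFF becomes 255, not -1.
constexpr uint32_t zextScale(uint8_t scale) {
  return static_cast<uint32_t>(scale);
}

// vector<4xi8> -> i32: a bitcast on a little-endian target places
// element i in bits [8*i, 8*i + 8).
constexpr uint32_t packScales4(const std::array<uint8_t, 4>& scales) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i)
    packed |= static_cast<uint32_t>(scales[i]) << (8 * i);
  return packed;
}

// vector<8xi8> -> i64: same layout, widened to 64 bits.
constexpr uint64_t packScales8(const std::array<uint8_t, 8>& scales) {
  uint64_t packed = 0;
  for (int i = 0; i < 8; ++i)
    packed |= static_cast<uint64_t>(scales[i]) << (8 * i);
  return packed;
}
```

The shifts spell out the little-endian packing explicitly instead of using `memcpy`, so the model gives the same result regardless of the host's byte order.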

https://github.com/llvm/llvm-project/pull/169854

