[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 1 06:26:53 PST 2025


================
@@ -667,21 +667,44 @@ static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
                               Value input) {
   Type inputType = input.getType();
 
-  // Handle scalar i8: zero extend to i32
+  // Handle scalar i8: zero extend to i32.
   if (auto intType = dyn_cast<IntegerType>(inputType))
     return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(), input);
 
-  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64
+  // Handle vector<4xi8> -> i32 or vector<8xi8> -> i64.
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     int64_t numElements = vectorType.getNumElements();
-    Type outputType = (numElements == 4) ? (Type)rewriter.getI32Type()
-                                         : (Type)rewriter.getI64Type();
+    IntegerType outputType =
+        (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
     return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
   }
 
   llvm_unreachable("unexpected input type for scale operand");
 }
 
+/// Maps f8 scale element types to WMMA scale format codes.
+static std::optional<uint32_t> getWmmaScaleFormat(Type elemType) {
+  return TypeSwitch<Type, std::optional<uint32_t>>(elemType)
+      .Case<Float8E8M0FNUType>([](auto) { return 0; })
+      .Case<Float8E4M3FNType>([](auto) { return 2; })
+      .Default([](Type) { return std::nullopt; });
----------------
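kuhar wrote:

Let `Case` deduce the matched type from the lambda's parameter type instead of spelling it as an explicit template argument, and pass the fallback value to `Default` directly rather than wrapping it in a lambda: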

```suggestion
      .Case([](Float8E8M0FNUType) { return 0; })
      .Case([](Float8E4M3FNType) { return 2; })
      .Default(std::nullopt);
```

https://github.com/llvm/llvm-project/pull/169854


More information about the Mlir-commits mailing list