[Mlir-commits] [mlir] [mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp (PR #155951)
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed Sep 17 12:50:13 PDT 2025
================
@@ -631,6 +633,148 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether the scales input is used by other scaled MFMAs, if any
+/// exist. If the scales are otherwise unused, pack them.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non-zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &use) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
+        switch (use.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      // Conservatively treat uses outside of scaled MFMAs as unpackable.
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain flat index from offsets and shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+      int cumul = 1;
----------------
krzysz00 wrote:
int64_t for things holding shapes
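For reference, a minimal sketch of how the helper could pick up that suggestion. Everything past `int cumul = 1;` is an assumption for illustration (the rest of the lambda is not quoted in this hunk), and the `extractOp` parameter name and the loop are hypothetical; the point is only that the accumulator and the flat index should be `int64_t`, matching the `int64_t` dimensions returned by `ShapedType::getShape()`.

```cpp
// Hypothetical sketch, not the PR's actual code: compute a flat index from
// the static positions of a vector.extract, using int64_t throughout since
// ShapedType::getShape() yields int64_t dimensions.
// Assumes a scalar extract, i.e. one offset per vector dimension.
auto getIdxFromExtract = [](vector::ExtractOp extractOp) -> int64_t {
  ShapedType ty = cast<ShapedType>(extractOp.getOperand(0).getType());
  ArrayRef<int64_t> shape = ty.getShape();
  ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
  int64_t cumul = 1;
  int64_t flatIdx = 0;
  // Walk dimensions from innermost to outermost, accumulating the stride.
  for (int64_t i = static_cast<int64_t>(offsets.size()) - 1; i >= 0; --i) {
    flatIdx += offsets[i] * cumul;
    cumul *= shape[i];
  }
  return flatIdx;
};
```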
https://github.com/llvm/llvm-project/pull/155951
More information about the Mlir-commits mailing list