[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Add lowering for load_nd with MX scale element type. (PR #198020)
Artem Kroviakov
llvmlistbot at llvm.org
Wed May 27 06:36:36 PDT 2026
================
@@ -266,6 +266,124 @@ class CreateNdDescToXeVMPattern
}
};
+class LoadNdMXScaleToXeVMPattern : public OpConversionPattern<xegpu::LoadNdOp> {
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp op, xegpu::LoadNdOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // load_nd ops for MX A Scale and B Scale are not directly supported by
+ // Intel HW's 2D block load. Lower this special case to an alternative
+ // code sequence using simple loads.
+ auto tdescTy = op.getTensorDescType();
+ if (tdescTy.getElementType() != rewriter.getF8E8M0Type())
+ return failure();
+ if (tdescTy.getRank() != 2)
+ return failure();
+ // Supported tile shapes are
+ // 8x1 or 8x2 for A Scale and
+ // 1x16 or 2x16 for B Scale
+ auto tileShape = tdescTy.getShape();
+ bool isAScale = tileShape[0] == 8;
+ bool isBScale = tileShape[1] == 16;
+ if (!isAScale && !isBScale)
+ return failure();
+ Type resTy = getTypeConverter()->convertType(op.getValue().getType());
+ auto resVecTy = dyn_cast<VectorType>(resTy);
+ if (resVecTy) {
+ if (resVecTy.getRank() != 1)
+ return failure();
+ if (resVecTy.getShape()[0] != 2)
----------------
akroviakov wrote:
So we can only load 2 e8m0 scales?
It would be nice to have a reference link explaining this restriction.
https://github.com/llvm/llvm-project/pull/198020
More information about the Mlir-commits
mailing list