[Mlir-commits] [mlir] [MLIR][XeGPU][XeVM] Add lowering for load_nd with MX scale element type. (PR #198020)

Artem Kroviakov llvmlistbot at llvm.org
Wed May 27 06:36:36 PDT 2026


================
@@ -266,6 +266,124 @@ class CreateNdDescToXeVMPattern
   }
 };
 
+class LoadNdMXScaleToXeVMPattern : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, xegpu::LoadNdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // load_nd ops for MX A Scale and B Scale are not directly supported by
+    // Intel HW's 2D block load. Lower this special case to an alternative
+    // code sequence using simple loads.
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getElementType() != rewriter.getF8E8M0Type())
+      return failure();
+    if (tdescTy.getRank() != 2)
+      return failure();
+    // Supported tile shapes are
+    // 8x1 or 8x2 for A Scale and
+    // 1x16 or 2x16 for B Scale
+    auto tileShape = tdescTy.getShape();
+    bool isAScale = tileShape[0] == 8;
+    bool isBScale = tileShape[1] == 16;
+    if (!isAScale && !isBScale)
+      return failure();
+    Type resTy = getTypeConverter()->convertType(op.getValue().getType());
+    auto resVecTy = dyn_cast<VectorType>(resTy);
+    if (resVecTy) {
+      if (resVecTy.getRank() != 1)
+        return failure();
+      if (resVecTy.getShape()[0] != 2)
----------------
akroviakov wrote:

So we can only load 2 e8m0 scales?
It would be nice to have a reference link explaining this restriction.

https://github.com/llvm/llvm-project/pull/198020


More information about the Mlir-commits mailing list