[Mlir-commits] [mlir] [MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (PR #162780)

Jianhui Li llvmlistbot at llvm.org
Tue Oct 14 08:25:41 PDT 2025


================
@@ -503,6 +506,187 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
+// Lowers xegpu::CreateMemDescOp to memref::ViewOp.
+//
+// The view's result type is whatever the pattern's type converter produces for
+// the op's xegpu::MemDescType. Background from the original note on this
+// pattern: SLM access instructions on Xe2 and Xe3 operate on 32-bit or 64-bit
+// units, so data types smaller than 32 bits are converted to 32 bits. That
+// widening is not done in this class; presumably it is implemented by the type
+// converter -- TODO confirm.
+class CreateMemDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+
+  // Replaces `op` with a memref::ViewOp over the op's source memref, built
+  // with a constant zero index and an empty ValueRange. Reports a match
+  // failure (without creating any IR) when the result type cannot be
+  // converted to a MemRefType; otherwise returns success.
+  LogicalResult
+  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto memDescTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+    // convertType<MemRefType> gives back a null type when the conversion
+    // fails or does not yield a MemRefType. Bail out before touching the IR
+    // instead of building a view op with a null result type.
+    auto viewTy = getTypeConverter()->convertType<MemRefType>(memDescTy);
+    if (!viewTy)
+      return rewriter.notifyMatchFailure(
+          op, "Expected mem_desc type to convert to a memref type.");
+
+    auto loc = op.getLoc();
+    TypedValue<MemRefType> src = op.getSource();
+    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    auto viewOp = memref::ViewOp::create(rewriter, loc, viewTy, Value(src),
+                                         zero, ValueRange());
+    rewriter.replaceOp(op, viewOp);
+    return success();
+  }
+};
+
+// Conversion pattern for xegpu::MemDescSubviewOp that never rewrites anything:
+// every invocation reports a match failure saying the op is unsupported.
+class MemDescSubviewOpPattern final
+    : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+  using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
+
+  // Always fails to match; `adaptor` is intentionally unused.
+  LogicalResult
+  matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    LogicalResult unsupported = rewriter.notifyMatchFailure(
+        op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
+    return unsupported;
+  }
+};
+
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    Value basePtrStruct = adaptor.getMemDesc();
+    Value mdescVal = op.getMemDesc();
+    // Load result or Store value Type can be vector or scalar.
+    Value data;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+      data = op.getResult();
+    else
+      data = adaptor.getData();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    if (!valOrResVecTy)
+      valOrResVecTy = VectorType::get(1, data.getType());
+
+    int64_t elemBitWidth =
+        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+
+    // Default memory space is SLM.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, basePtrStruct);
+
+    // Convert base pointer (ptr) to i64
+    Value basePtrI64 = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+
+    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+    linearOffset = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), linearOffset);
+    basePtrI64 =
+        addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+
+    // convert base pointer (i64) to LLVM pointer type
----------------
Jianhui-Li wrote:

changed to i32

https://github.com/llvm/llvm-project/pull/162780


More information about the Mlir-commits mailing list