[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)

Sang Ik Lee llvmlistbot at llvm.org
Wed Aug 20 10:55:35 PDT 2025


================
@@ -0,0 +1,926 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
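+// i32 slot indices of the vector<8xi32> payload that encodes an nd tensor
+// descriptor. The 64-bit base pointer occupies i32 slots 0 and 1 (it is
+// written through a vector<4xi64> view at index 0), so slot 1 has no
+// separate entry.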
+enum class NdDescI32Layout : uint32_t {
+  BasePtr = 0,
+  BaseShapeW = 2,
+  BaseShapeH = 3,
+  TensorOffsetW = 4,
+  TensorOffsetH = 5
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
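+// Checks whether the given values form an arithmetic sequence, i.e.
+// denseAttr[i] == slope * i + intercept for all i.
+// Returns {isLinear, slope, intercept}; slope and intercept are only
+// meaningful when isLinear is true.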
+template <typename T>
+std::tuple<bool, int32_t, int32_t>
+checkAllLinear(const SmallVector<T> &denseAttr) {
+  assert(!denseAttr.empty());
+  const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
+  if (denseAttr.size() < 2)
+    return {true, 0, intercept};
+  const T slope{denseAttr[1] - denseAttr[0]};
+  for (size_t i = 1; i < denseAttr.size(); ++i)
+    if (denseAttr[i] - denseAttr[i - 1] != slope)
+      return {false, 0, 0};
+  return {true, static_cast<int32_t>(slope), intercept};
+}
+
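+// Returns a 1-D vector type with element type toElemType whose total bit
+// width matches that of currentVecType.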
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
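+// Maps the optional XeGPU L1/L3 cache hints to a single XeVM load cache
+// control. Absent hints default to UNCACHED.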
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::CACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::READ_INVALIDATE:
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
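+// Maps the optional XeGPU L1/L3 cache hints to a single XeVM store cache
+// control. Absent hints default to UNCACHED.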
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_BACK:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
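+// Lowers xegpu.create_nd_tdesc to a vector<8xi32> payload holding the base
+// address, base shape and tensor offsets, following the NdDescI32Layout
+// slot assignment above.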
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    Type payloadElemTy = rewriter.getI32Type();
+    Type i64Ty = rewriter.getI64Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    bool sourceIsMemref = false;
+    auto sourceTy = source.getType();
+    int64_t rank;
+    if (isa<MemRefType>(sourceTy)) {
+      sourceIsMemref = true;
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      if (!sourceMemrefTy.hasStaticShape())
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected static memref shape.");
+      rank = sourceMemrefTy.getRank();
+      if (rank != 2)
+        return rewriter.notifyMatchFailure(op, "Expected a 2D memref.");
+    } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+      rank = op.getMixedSizes().size();
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Expected source to be a 2D memref or ui64.");
+    }
+    // Materializes an OpFoldResult as an i32 value, truncating dynamic
+    // index values and directly materializing constants.
+    auto createI32Value = [&](OpFoldResult ofr) -> Value {
+      Value val;
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+      } else {
+        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+      }
+      return val;
+    };
+    auto offsets = op.getMixedOffsets();
+    if (offsets.size() == 2) {
+      offsetW = createI32Value(offsets[rank - 1]);
+      offsetH = createI32Value(offsets[rank - 2]);
+    } else {
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    }
+    if (sourceIsMemref) {
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      baseShapeW = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+      baseShapeH = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    } else {
+      baseShapeW = createI32Value(op.getMixedSizes()[rank - 1]);
+      baseShapeH = createI32Value(op.getMixedSizes()[rank - 2]);
+      baseAddr = adaptor.getSource();
+    }
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+    payLoadAsI64 =
+        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+                                 static_cast<int>(NdDescI32Layout::BasePtr));
+    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeW));
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeH));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetW, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetH, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    rewriter.replaceOp(op, payload);
+    return success();
+  }
+};
+
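+// Lowers xegpu.update_nd_offset by adding the new offsets into the offset
+// slots of the descriptor payload. Offsets that are constant zero are
+// skipped.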
+class UpdateNdOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto offsets = op.getOffsets();
+    auto tdesc = adaptor.getTensorDesc();
+    for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+      auto offset = offsets[offsetDim];
+      if (auto cst =
+              dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+        if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+            attr && !attr.getInt())
+          continue;
+      const int offsetPos =
+          static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+                                     : NdDescI32Layout::TensorOffsetH);
+      auto oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+      offset = arith::IndexCastUIOp::create(rewriter, loc,
+                                            rewriter.getI32Type(), offset);
+      auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      tdesc =
+          vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+    }
+    rewriter.replaceOp(op, tdesc);
+    return success();
+  }
+};
+
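+// Lowers xegpu.load_nd, xegpu.store_nd and xegpu.prefetch_nd on 2D tensor
+// descriptors to the corresponding XeVM 2D block operation, decoding the
+// payload produced by CreateNdDescToXeVMPattern.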
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+
+    auto tdesc = adaptor.getTensorDesc();
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    }
+
+    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+    Value basePtr =
+        vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                  static_cast<int>(NdDescI32Layout::BasePtr));
+    Value baseShapeW = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+    Value baseShapeH = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+    // Offsets can come from three sources:
+    // 1. Constant offsets, which are provided by the op.
+    // 2. Offsets as operands, which are provided by the op.
+    // 3. Offsets extracted from the tensor descriptor.
+    Value offsetW;
+    Value offsetH;
+    auto cOffsets = op.getConstOffsets();
+    auto offsets = op.getOffsets();
+    if (cOffsets) {
+      // Offsets are ordered outermost-first: [0] is H, [1] is W.
+      offsetW = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
+      offsetH = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
+    } else if (offsets.size() != 0) {
+      // Offsets are provided as operands, ordered outermost-first:
+      // offsets[0] is H, offsets[1] is W.
+      if (offsets[1].getType() != rewriter.getI32Type()) {
+        if (offsets[1].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetW = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[1]);
+      } else {
+        offsetW = offsets[1];
+      }
+      if (offsets[0].getType() != rewriter.getI32Type()) {
+        if (offsets[0].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetH = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[0]);
+      } else {
+        offsetH = offsets[0];
+      }
+    } else {
+      // If offsets are not available, we need to extract them from the tensor
+      // descriptor.
+      offsetW = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetW));
+      offsetH = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    }
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    auto elemType = tdescTy.getElementType();
+    auto elemBitSize = elemType.getIntOrFloatBitWidth();
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    Value surfaceW =
+        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+    auto tileW = tdescTy.getDimSize(1);
+    auto tileH = tdescTy.getDimSize(0);
+    int32_t vblocks = tdescTy.getArrayLength();
+    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+      VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      VectorType srcFlatVecTy =
+          VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
+      Value srcFlatVec = op.getValue();
+      srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
+                                        rewriter.getIntegerType(elemBitSize));
+      srcFlatVec =
+          vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+      xevm::BlockStore2dOp::create(
+          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+          offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      rewriter.eraseOp(op);
+    } else {
+      auto loadCacheControl =
+          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+        xevm::BlockPrefetch2dOp::create(
+            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+            offsetH, elemBitSize, tileW, tileH, vblocks,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        rewriter.eraseOp(op);
+      } else {
+        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+        const bool vnni = op.getPacked().value_or(false);
+        auto transposeValue = op.getTranspose();
+        bool transpose =
+            transposeValue.has_value() && transposeValue.value()[0] == 1;
+        VectorType loadedTy = encodeVectorTypeTo(
+            dstVecTy, vnni ? rewriter.getI32Type()
+                           : rewriter.getIntegerType(elemBitSize));
+
+        Value resultFlatVec = xevm::BlockLoad2dOp::create(
+            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+            transpose, vnni,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        resultFlatVec = vector::BitCastOp::create(
+            rewriter, loc,
+            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+            resultFlatVec);
+        rewriter.replaceOp(op, resultFlatVec);
+      }
+    }
+    return success();
+  }
+};
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
+        xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
+int64_t getElemByteSize(OpType op) {
+  // Get the element byte size from the tensor descriptor.
+  auto elemBitWidth =
+      op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
+  return elemBitWidth / 8;
+}
+
+// A builder that computes
+// baseAddr + offset * elemByteSize.
+auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
+                    Value baseAddr, Value offset,
+                    int64_t elemByteSize) -> Value {
+  Value byteSize = arith::ConstantIntOp::create(
+      rewriter, loc, rewriter.getI64Type(), elemByteSize);
+  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+  return newAddr;
+};
+
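+// Lowers xegpu.create_desc, materializing per-lane addresses from the base
+// address and the scattered offsets (see addOffset above).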
+class CreateDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto offsets = adaptor.getOffsets();
----------------
silee2 wrote:

That part of the description needs to be updated.
A vector of offsets fits the subgroup level, but after SIMT distribution the
ops consuming the descriptor are no longer cooperative and become regular
per-lane load / store / prefetch ops.
vector<1xindex> fits that abstraction better.
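
For illustration (a sketch only; exact syntax assumed, not taken from this
PR), after SIMT distribution each lane would carry a single offset, e.g.

  %tdesc = xegpu.create_tdesc %base, %offset
      : ui64, vector<1xindex>
      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>

so the consuming load / store / prefetch lowers to an ordinary per-lane
memory access instead of a cooperative subgroup operation.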

https://github.com/llvm/llvm-project/pull/154556

