[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Aug 27 10:51:14 PDT 2025


================
@@ -0,0 +1,1019 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+// Offsets of the individual fields within the 8xi32 payload that encodes an
+// nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+  BasePtr = 0,       // Base pointer (i64)
+  BaseShapeW = 2,    // Base shape width (i32)
+  BaseShapeH = 3,    // Base shape height (i32)
+  TensorOffsetW = 4, // Tensor offset W (i32)
+  TensorOffsetH = 5  // Tensor offset H (i32)
+};
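+// Viewed as i32 lanes, the payload layout implied by these offsets is:
+//   lane:  | 0      1 | 2      | 3      | 4       | 5       | 6 | 7 |
+//   field: | base ptr | shapeW | shapeH | offsetW | offsetH | - | - |
+// The i64 base pointer occupies lanes 0-1 via a 4xi64 view of the payload;
+// lanes 6 and 7 are unused.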
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+// Get a flat vector type of the new element type that preserves the total
+// bit width of the original vector type.
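+// For example, a vector<8xf16> re-encoded with an i32 element type becomes
+// vector<4xi32> (8 elements * 16 bits / 32 bits = 4 elements).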
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+                                     Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
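+  // E.g., L1 CACHED combined with L3 UNCACHED maps to L1C_L2UC_L3UC.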
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::CACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::READ_INVALIDATE:
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
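+  // E.g., L1 WRITE_BACK combined with L3 WRITE_BACK maps to L1WB_L2UC_L3WB.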
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_BACK:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    // The op is lowered to a code sequence that populates the payload.
+    // The payload is an 8xi32 vector. Offsets to individual fields are
+    // defined in the NdTdescOffset enum.
+    Type payloadElemTy = rewriter.getI32Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    Type i64Ty = rewriter.getI64Type();
+    // 4xi64 view is used for inserting the base pointer.
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    // Initialize payload to zero.
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    // Descriptor shape is expected to be 2D.
+    int64_t rank = mixedSizes.size();
+    if (rank != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+    auto sourceTy = source.getType();
+    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+    // If the source is a memref, extract the aligned pointer as an index.
+    // Pointer-typed sources are passed as i32 or i64 by the type converter.
+    if (sourceMemrefTy) {
+      if (!sourceMemrefTy.hasStaticShape()) {
+        op.emitError() << "Expected static memref shape.";
+        return failure();
+      }
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+    } else {
+      baseAddr = adaptor.getSource();
+    }
+    // Utility for creating offset values from op fold result.
+    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+                            unsigned idx) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+      return val;
+    };
+    // Offsets are either 2D or absent (in which case 0 is used).
+    if (mixedOffsets.size() == 2) {
+      offsetW = createOffset(mixedOffsets, 1);
+      offsetH = createOffset(mixedOffsets, 0);
+    } else if (mixedOffsets.size() == 0) {
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected 2D offsets or no offsets.");
+    }
+    // Get shape values from op fold results.
+    baseShapeW = createOffset(mixedSizes, 1);
+    baseShapeH = createOffset(mixedSizes, 0);
+    if (sourceMemrefTy)
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    else if (baseAddr.getType() != i64Ty)
+      // Pointer type may be i32. Cast to i64 if needed.
+      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+
+    // Populate payload.
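+    // The base address is inserted through the 4xi64 view so that the i64
+    // value lands in i32 lanes 0 and 1 of the payload.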
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+    payLoadAsI64 =
+        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+                                 static_cast<int>(NdTdescOffset::BasePtr));
+    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+                                 static_cast<int>(NdTdescOffset::BaseShapeW));
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+                                 static_cast<int>(NdTdescOffset::BaseShapeH));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetW, payload,
+        static_cast<int>(NdTdescOffset::TensorOffsetW));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetH, payload,
+        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    rewriter.replaceOp(op, payload);
+    return success();
+  }
+};
+
+class UpdateNdOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto mixedOffsets = op.getMixedOffsets();
+    // Only 2D offsets are supported for now.
+    if (mixedOffsets.size() != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+    auto tdesc = adaptor.getTensorDesc();
+    // Utility for updating payload offset values from op fold result.
+    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI32Type(), offset);
+      Value oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+                                      payloadPos);
+    };
+    // Update offsets in the payload.
+    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, val);
+    return success();
+  }
+};
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+
+    auto tdesc = adaptor.getTensorDesc();
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    auto elemType = tdescTy.getElementType();
+    auto elemBitSize = elemType.getIntOrFloatBitWidth();
+    if (elemBitSize % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+
+    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+    Value basePtr = vector::ExtractOp::create(
+        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
+    Value baseShapeW = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+    Value baseShapeH = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    // Offsets are provided in one of two ways:
+    // 1. As mixed offsets directly on the op.
+    // 2. Otherwise, extracted from the tensor descriptor.
+    Value offsetW;
+    Value offsetH;
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 0 && opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected 2D offsets or no offsets.");
+    if (opOffsetsSize) {
+      // If mixed offsets are provided by the op, convert them to i32.
+      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
+    } else {
+      // If offsets are not available, we need to extract them from the tensor
+      // descriptor.
+      offsetW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
+      offsetH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    }
+    // Get address space from tensor descriptor memory space.
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    // Convert base pointer (i64) to LLVM pointer type.
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    // Compute element byte size and surface width in bytes.
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    Value surfaceW =
+        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+    // Get tile sizes and vblocks from the tensor descriptor type.
+    auto tileW = tdescTy.getDimSize(1);
+    auto tileH = tdescTy.getDimSize(0);
+    int32_t vblocks = tdescTy.getArrayLength();
+    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      if (!srcVecTy)
+        return rewriter.notifyMatchFailure(
+            op, "Expected store value to be a vector type.");
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      Value src = adaptor.getValue();
+      // Get flat vector type of integer type with matching element bit size.
+      VectorType newSrcVecTy =
+          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+      if (srcVecTy != newSrcVecTy)
+        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      xevm::BlockStore2dOp::create(
+          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+          offsetH, elemBitSize, tileW, tileH, src,
+          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      rewriter.eraseOp(op);
+    } else {
+      auto loadCacheControl =
+          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+        xevm::BlockPrefetch2dOp::create(
+            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+            offsetH, elemBitSize, tileW, tileH, vblocks,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        rewriter.eraseOp(op);
+      } else {
+        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+        const bool vnni = op.getPacked().value_or(false);
+        auto transposeValue = op.getTranspose();
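+        // A transpose is requested when the permutation's leading dimension
+        // is 1, i.e. the [1, 0] permutation of a 2D descriptor.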
+        bool transpose =
+            transposeValue.has_value() && transposeValue.value()[0] == 1;
+        VectorType loadedTy = encodeVectorTypeTo(
+            dstVecTy, vnni ? rewriter.getI32Type()
+                           : rewriter.getIntegerType(elemBitSize));
+
+        Value resultFlatVec = xevm::BlockLoad2dOp::create(
+            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+            transpose, vnni,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        resultFlatVec = vector::BitCastOp::create(
+            rewriter, loc,
+            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+            resultFlatVec);
+        rewriter.replaceOp(op, resultFlatVec);
+      }
+    }
+    return success();
+  }
+};
+
+// Emits IR computing: offset * elemByteSize + baseAddr.
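+// For example, with 4-byte (f32) elements and an offset of 5, the emitted IR
+// computes baseAddr + 20.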
+static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
+                       Value baseAddr, Value offset, int64_t elemByteSize) {
+  Value byteSize = arith::ConstantIntOp::create(
+      rewriter, loc, rewriter.getI64Type(), elemByteSize);
+  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+  return newAddr;
+}
+
+class CreateDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    auto loc = op.getLoc();
+    // Offsets are provided as a scalar i64 by the type converter.
+    auto offsets = adaptor.getOffsets();
+    // The source can be a 1D memref or a pointer (ui64, ui32, i64 or i32),
+    // but the type converter lowers all of these to integer types.
+    Value addr = adaptor.getSource();
+    // ui32 or i32 sources are passed as i32, so they need to be cast to i64.
+    if (addr.getType() != rewriter.getI64Type())
+      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
+    rewriter.replaceOp(op, laneAddr);
+    return success();
+  }
+};
+
+class UpdateOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateOffsetOp op,
+                  xegpu::UpdateOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    auto loc = op.getLoc();
+    // The scatter descriptor is provided as a scalar i64 by the type
+    // converter, and so are the offsets.
+    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+                                adaptor.getOffsets(), eBw / 8);
+    rewriter.replaceOp(op, newOffset);
+    return success();
+  }
+};
+
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdescTy = op.getTensorDescType();
+    Value basePtrI64;
+    // The load result or store value type can be a vector or a scalar.
+    Type valOrResTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
+      valOrResTy = op.getResult().getType();
+    else
+      valOrResTy = adaptor.getValue().getType();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+    bool hasScalarVal = !valOrResVecTy;
+    int64_t elemBitWidth =
+        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy)
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    // Base pointer can come from source (load) or dest (store).
+    // If they are memrefs, we use their memory space.
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      basePtrI64 = adaptor.getSource();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
+    } else {
+      basePtrI64 = adaptor.getDest();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
+    }
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
+    }
+    Value offsets = adaptor.getOffsets();
+    Value mask = adaptor.getMask();
+    if (offsets) {
+      if (dyn_cast<VectorType>(offsets.getType()))
+        // Offsets need to be a scalar. A single-element vector is converted
+        // to a scalar by the type converter.
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      else
+        // If offsets are provided, add them to the base pointer.
+        // Offsets are in numbers of elements, so multiply by the
+        // element byte size.
+        basePtrI64 =
----------------
adam-smnk wrote:

braces
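
i.e. add braces around both branches here - with the comments they span
multiple lines. Roughly (the elided right-hand side left as in the quoted
hunk):

    if (dyn_cast<VectorType>(offsets.getType())) {
      // Offsets need to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op,
                                         "Expected offsets to be a scalar.");
    } else {
      // If offsets are provided, add them to the base pointer. Offsets are
      // in numbers of elements, so multiply by the element byte size.
      basePtrI64 = ...; // continues as in the original patch
    }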

https://github.com/llvm/llvm-project/pull/154556


More information about the Mlir-commits mailing list