[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Aug 21 08:25:15 PDT 2025


================
@@ -0,0 +1,913 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
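+// Lane indices of the descriptor fields within the vector<8xi32> payload
+// that represents an Nd tensor descriptor. The base pointer occupies two
+// i32 lanes (one i64), hence the gap at index 1.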
+enum class NdDescI32Layout : uint32_t {
+  BasePtr = 0,
+  BaseShapeW = 2,
+  BaseShapeH = 3,
+  TensorOffsetW = 4,
+  TensorOffsetH = 5
+};
+
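+// Maps an XeGPU memory space to the numeric XeVM address space.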
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int>(xevm::AddrSpace::SHARED);
+  }
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
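+// Returns a vector type with the same total bit width as `currentVecType`
+// but with `toElemType` elements, adjusting the element count accordingly.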
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int size =
+      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
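+// Combines XeGPU L1/L3 cache hints into a single XeVM load cache control.
+// L2 is always uncached; absent hints default to UNCACHED.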
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::CACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::CACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3C;
+    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::READ_INVALIDATE:
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
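+// Combines XeGPU L1/L3 cache hints into a single XeVM store cache control.
+// L2 is always uncached; absent hints default to UNCACHED.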
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  switch (L1hintVal) {
+  case xegpu::CachePolicy::UNCACHED:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::STREAMING:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_BACK:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+    else
+      llvm_unreachable("Unsupported cache control.");
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
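+// Lowers xegpu.create_nd_tdesc to a vector<8xi32> payload carrying the base
+// address, base shape, and tensor offsets (see NdDescI32Layout).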
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    Type payloadElemTy = rewriter.getI32Type();
+    Type i64Ty = rewriter.getI64Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    bool sourceIsMemref = false;
+    auto sourceTy = source.getType();
+    int64_t rank;
+    if (isa<MemRefType>(sourceTy)) {
+      sourceIsMemref = true;
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      if (!sourceMemrefTy.hasStaticShape()) {
+        op.emitError() << "Expected static memref shape.";
+        return failure();
+      }
+      rank = sourceMemrefTy.getRank();
+      if (rank != 2) {
+        op.emitError() << "Expected a 2D memref.";
+        return failure();
+      }
+    } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+      rank = op.getMixedSizes().size();
+    } else {
+      op.emitError() << "Expected source to be a 2D memref or ui64.";
+      return failure();
+    }
+    auto createOffset = [&](unsigned idx) -> Value {
+      Value val;
+      OpFoldResult ofr = op.getMixedOffsets()[idx];
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+      } else {
+        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+      }
+      return val;
+    };
+    auto offsets = op.getMixedOffsets();
+    if (offsets.size() == 2) {
+      offsetW = createOffset(rank - 1);
+      offsetH = createOffset(rank - 2);
+    } else {
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    }
+    auto createShape = [&](unsigned idx) -> Value {
+      Value val;
+      OpFoldResult ofr = op.getMixedSizes()[idx];
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+      } else {
+        int32_t size = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, size);
+      }
+      return val;
+    };
+    if (sourceIsMemref) {
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      baseShapeW = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+      baseShapeH = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    } else {
+      baseShapeW = createShape(rank - 1);
+      baseShapeH = createShape(rank - 2);
+      baseAddr = adaptor.getSource();
+    }
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+    payLoadAsI64 =
+        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+                                 static_cast<int>(NdDescI32Layout::BasePtr));
+    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeW));
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeH));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetW, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetH, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    rewriter.replaceOp(op, payload);
+    return success();
+  }
+};
+
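+// Lowers xegpu.update_nd_offset by adding the new offsets onto the ones
+// stored in the descriptor payload. Constant-zero offsets are skipped.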
+class UpdateNdOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto offsets = op.getOffsets();
+    auto tdesc = adaptor.getTensorDesc();
+    for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+      auto offset = offsets[offsetDim];
+      if (auto cst =
+              dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+        if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+            attr && !attr.getInt())
+          continue;
+      const int offsetPos =
+          static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+                                     : NdDescI32Layout::TensorOffsetH);
+      auto oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+      offset = arith::IndexCastUIOp::create(rewriter, loc,
+                                            rewriter.getI32Type(), offset);
+      auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      tdesc =
+          vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+    }
+    rewriter.replaceOp(op, tdesc);
+    return success();
+  }
+};
+
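+// Lowers xegpu.load_nd, xegpu.store_nd, and xegpu.prefetch_nd to the
+// corresponding XeVM 2D block operations, recovering the base pointer,
+// surface shape, and offsets from the descriptor payload.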
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+
+    auto tdesc = adaptor.getTensorDesc();
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    }
+
+    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+    Value basePtr =
+        vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                  static_cast<int>(NdDescI32Layout::BasePtr));
+    Value baseShapeW = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+    Value baseShapeH = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+    // Offsets can come from three sources:
+    // 1. Constant offsets attached to the op.
+    // 2. Offsets passed as op operands.
+    // 3. Offsets stored in the tensor descriptor.
+    Value offsetW;
+    Value offsetH;
+    auto cOffsets = op.getConstOffsets();
+    auto offsets = op.getOffsets();
+    if (cOffsets) {
+      // Constant offsets are ordered [height, width].
+      offsetH = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
+      offsetW = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
+    } else if (!offsets.empty()) {
+      // Offsets are provided as operands, ordered [height, width].
+      if (offsets[0].getType() != rewriter.getI32Type()) {
+        if (offsets[0].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetH = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[0]);
+      } else {
+        offsetH = offsets[0];
+      }
+      if (offsets[1].getType() != rewriter.getI32Type()) {
+        if (offsets[1].getType() != rewriter.getIndexType()) {
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+        }
+        offsetW = arith::IndexCastUIOp::create(
+            rewriter, loc, rewriter.getI32Type(), offsets[1]);
+      } else {
+        offsetW = offsets[1];
+      }
+    } else {
+      // If offsets are not available, we need to extract them from the tensor
+      // descriptor.
+      offsetW = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetW));
+      offsetH = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    }
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    auto elemType = tdescTy.getElementType();
+    auto elemBitSize = elemType.getIntOrFloatBitWidth();
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    Value surfaceW =
+        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+    auto tileW = tdescTy.getDimSize(1);
+    auto tileH = tdescTy.getDimSize(0);
+    int32_t vblocks = tdescTy.getArrayLength();
+    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+      VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      VectorType srcFlatVecTy =
+          VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
+      Value srcFlatVec = op.getValue();
+      srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
+                                        rewriter.getIntegerType(elemBitSize));
+      srcFlatVec =
+          vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+      xevm::BlockStore2dOp::create(
+          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+          offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      rewriter.eraseOp(op);
+    } else {
+      auto loadCacheControl =
+          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+        xevm::BlockPrefetch2dOp::create(
+            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+            offsetH, elemBitSize, tileW, tileH, vblocks,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        rewriter.eraseOp(op);
+      } else {
+        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+        const bool vnni = op.getPacked().value_or(false);
+        auto transposeValue = op.getTranspose();
+        bool transpose =
+            transposeValue.has_value() && transposeValue.value()[0] == 1;
+        VectorType loadedTy = encodeVectorTypeTo(
+            dstVecTy, vnni ? rewriter.getI32Type()
+                           : rewriter.getIntegerType(elemBitSize));
+
+        Value resultFlatVec = xevm::BlockLoad2dOp::create(
+            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+            transpose, vnni,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        resultFlatVec = vector::BitCastOp::create(
+            rewriter, loc,
+            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+            resultFlatVec);
+        rewriter.replaceOp(op, resultFlatVec);
+      }
+    }
+    return success();
+  }
+};
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
+        xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
+int64_t getElemByteSize(OpType op) {
+  // Get the element byte size from the tensor descriptor.
+  auto elemBitWidth =
+      op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
+  return elemBitWidth / 8;
+}
+
+// Builds the effective address: baseAddr + offset * elemByteSize.
+auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
+                    Value baseAddr, Value offset,
+                    int64_t elemByteSize) -> Value {
+  Value byteSize = arith::ConstantIntOp::create(
+      rewriter, loc, rewriter.getI64Type(), elemByteSize);
+  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+  return newAddr;
+};
+
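+// Lowers xegpu.create_tdesc to a per-lane i64 address:
+// baseAddr + offset * elemByteSize.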
+class CreateDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto offsets = adaptor.getOffsets();
+    // The source can be a 1D memref or a ui64 address.
+    // Use "op" instead of "adaptor" here to access the memref type rather
+    // than the converted LLVM struct type.
+    auto memrefTy = dyn_cast<MemRefType>(op.getSource().getType());
+    Value subGroupAddr;
+    if (memrefTy) {
+      subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, loc, op.getSource());
+      subGroupAddr = arith::IndexCastUIOp::create(
+          rewriter, loc, rewriter.getI64Type(), subGroupAddr);
+    } else {
+      subGroupAddr = adaptor.getSource();
+    }
+    auto laneAddr =
+        addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op));
+    rewriter.replaceOp(op, laneAddr);
+    return success();
+  }
+};
+
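+// Lowers xegpu.update_offset by adding the scaled byte offset to the lane
+// address.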
+class UpdateOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateOffsetOp op,
+                  xegpu::UpdateOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value newOffsetForLane =
+        addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
+                  getElemByteSize(op));
+    rewriter.replaceOp(op, newOffsetForLane);
+    return success();
+  }
+};
+
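+// Lowers xegpu.load and xegpu.store (scattered variants) to LLVM loads and
+// stores guarded by an scf.if on the per-lane mask. A masked-off load yields
+// a zero vector.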
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdescTy = op.getTensorDescType();
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrI64;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      basePtrI64 = adaptor.getSource();
+    } else {
+      basePtrI64 = adaptor.getDest();
+    }
+    Value offsets = adaptor.getOffsets();
+    Value mask = adaptor.getMask();
+    if (offsets) {
+      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+      if (offsetsVecTy) {
+        // Offsets need to be scalar.
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      } else {
+        basePtrI64 =
+            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+      }
+    }
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    VectorType srcOrDstVecTy = op.getValueType();
+    VectorType srcOrDstFlatVecTy = VectorType::get(
+        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
+    if (maskVecTy)
+      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+    Value maskForLane = mask;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy},
+                                         maskForLane, true, true);
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      Value loaded =
+          LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
+      loaded.getDefiningOp()->setAttr(
+          "cache_control", xevm::LoadCacheControlAttr::get(
+                               ctxt, translateLoadXeGPUCacheHint(
+                                         op.getL1Hint(), op.getL3Hint())));
+      if (srcOrDstVecTy != srcOrDstFlatVecTy) {
+        loaded =
+            vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
+      }
+      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // If mask is false, we yield a vector of zeros.
+      auto eTy = srcOrDstVecTy.getElementType();
+      loaded = arith::ConstantOp::create(
+          rewriter, loc,
+          eTy.isFloat()
+              ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0))
+              : DenseElementsAttr::get(srcOrDstVecTy,
+                                       IntegerAttr::get(eTy, 0)));
+      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+      rewriter.replaceOp(op, ifOp.getResult(0));
+    } else {
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
+      auto body = ifOp.getBody();
+      rewriter.setInsertionPointToStart(body);
+      VectorType valTy = op.getValue().getType();
+      Value srcFlatVec = op.getValue();
+      if (valTy != srcOrDstFlatVecTy) {
+        srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
+                                                 srcOrDstFlatVecTy, srcFlatVec);
+      }
+      auto storeOp =
+          LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+      storeOp.getOperation()->setAttr(
+          "cache_control", xevm::StoreCacheControlAttr::get(
+                               ctxt, translateStoreXeGPUCacheHint(
+                                         op.getL1Hint(), op.getL3Hint())));
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+
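+// Lowers xegpu.prefetch to xevm.prefetch on the computed address.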
+class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdescTy = op.getTensorDescType();
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrI64 = adaptor.getSource();
+    Value offsets = adaptor.getOffsets();
+    if (offsets) {
+      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+      if (offsetsVecTy) {
+        // Offsets need to be scalar.
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      } else {
+        basePtrI64 =
+            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+      }
+    }
+    Value ptrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    xevm::PrefetchOp::create(
+        rewriter, loc, ptrLLVM,
+        xevm::LoadCacheControlAttr::get(
+            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
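+
+// Lowers xegpu.fence to xevm.memfence with the matching memory scope and
+// address space.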
+class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
+    switch (op.getFenceScope()) {
+    case xegpu::FenceScope::Workgroup:
+      memScope = xevm::MemScope::WORKGROUP;
+      break;
+    case xegpu::FenceScope::GPU:
+      memScope = xevm::MemScope::DEVICE;
+      break;
+    }
+    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
+    switch (op.getMemoryKind()) {
+    case xegpu::MemorySpace::Global:
+      addrSpace = xevm::AddrSpace::GLOBAL;
+      break;
+    case xegpu::MemorySpace::SLM:
+      addrSpace = xevm::AddrSpace::SHARED;
+      break;
+    }
+    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
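+// Lowers xegpu.dpas to xevm.mma, materializing a zero accumulator when none
+// is provided.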
+class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto aTy = cast<VectorType>(op.getLhs().getType());
+    auto bTy = cast<VectorType>(op.getRhs().getType());
+    auto resultType = cast<VectorType>(op.getResultType());
+
+    auto encodePrecision = [&](Type type) -> xevm::ElemType {
+      if (type == rewriter.getBF16Type())
+        return xevm::ElemType::BF16;
+      else if (type == rewriter.getF16Type())
+        return xevm::ElemType::F16;
+      else if (type == rewriter.getTF32Type())
+        return xevm::ElemType::TF32;
+      else if (type.isInteger(8)) {
+        if (type.isUnsignedInteger())
+          return xevm::ElemType::U8;
+        return xevm::ElemType::S8;
+      } else if (type == rewriter.getF32Type())
+        return xevm::ElemType::F32;
+      else if (type.isInteger(32))
+        return xevm::ElemType::S32;
+      llvm_unreachable("Unsupported DPAS element type.");
+    };
+    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
+    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
+    Value c = op.getAcc();
+    if (!c) {
+      auto elementTy = resultType.getElementType();
+      Attribute initValueAttr;
+      if (isa<FloatType>(elementTy))
+        initValueAttr = FloatAttr::get(elementTy, 0.0);
+      else
+        initValueAttr = IntegerAttr::get(elementTy, 0);
+      c = arith::ConstantOp::create(
+          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
+    }
+
+    Value aVec = op.getLhs();
+    Value bVec = op.getRhs();
+    auto cvecty = cast<VectorType>(c.getType());
+    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
+    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
+    VectorType cNty =
+        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
+    if (cvecty != cNty)
+      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
+    // The values below are uArch-dependent and should eventually come from a
+    // uArch description instead of being hardcoded.
+    constexpr int32_t systolicDepth{8};
+    constexpr int32_t executionSize{16};
+    Value dpasRes = xevm::MMAOp::create(
+        rewriter, loc, cNty, aVec, bVec, c,
+        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
+                                systolicDepth *
+                                    getNumOperandsPerDword(precATy)),
+        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
+    if (cvecty != cNty)
+      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
+    rewriter.replaceOp(op, dpasRes);
+    return success();
+  }
+
+private:
+  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+    switch (pTy) {
+    case xevm::ElemType::TF32:
+      return 1;
+    case xevm::ElemType::BF16:
+    case xevm::ElemType::F16:
+      return 2;
+    case xevm::ElemType::U8:
+    case xevm::ElemType::S8:
+      return 4;
+    default:
+      llvm_unreachable("unsupported xevm::ElemType");
+    }
+  }
+};
+
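+// Maps an arith::AtomicRMWKind to the corresponding LLVM atomic binary op,
+// or std::nullopt if there is no direct equivalent.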
+static std::optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
+  switch (arithKind) {
+  case arith::AtomicRMWKind::addf:
+    return LLVM::AtomicBinOp::fadd;
+  case arith::AtomicRMWKind::addi:
+    return LLVM::AtomicBinOp::add;
+  case arith::AtomicRMWKind::assign:
+    return LLVM::AtomicBinOp::xchg;
+  case arith::AtomicRMWKind::maximumf:
+    return LLVM::AtomicBinOp::fmax;
+  case arith::AtomicRMWKind::maxs:
+    return LLVM::AtomicBinOp::max;
+  case arith::AtomicRMWKind::maxu:
+    return LLVM::AtomicBinOp::umax;
+  case arith::AtomicRMWKind::minimumf:
+    return LLVM::AtomicBinOp::fmin;
+  case arith::AtomicRMWKind::mins:
+    return LLVM::AtomicBinOp::min;
+  case arith::AtomicRMWKind::minu:
+    return LLVM::AtomicBinOp::umin;
+  case arith::AtomicRMWKind::ori:
+    return LLVM::AtomicBinOp::_or;
+  case arith::AtomicRMWKind::andi:
+    return LLVM::AtomicBinOp::_and;
+  default:
+    return std::nullopt;
+  }
+}
+
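+// Lowers xegpu.atomic_rmw to a sequence of per-element llvm.atomicrmw ops.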
+class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdesc = op.getTensorDesc().getType();
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
+    Value basePtrI64 = arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
+    VectorType srcOrDstFlatVecTy = VectorType::get(
+        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+    Value srcFlatVec = vector::ShapeCastOp::create(
+        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
+    auto atomicKind = matchSimpleAtomicOp(op.getKind());
+    if (!atomicKind)
+      return rewriter.notifyMatchFailure(op, "Unsupported atomic kind.");
+    Value resVec = srcFlatVec;
+    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
+      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
+      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+                                           rewriter.getIndexAttr(i));
+      Value currPtr =
+          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
+                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
+      Value newVal =
+          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
+                                    val, LLVM::AtomicOrdering::seq_cst);
+      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
+    }
+    rewriter.replaceOp(op, resVec);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeGPUToXeVMPass
+    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    LLVMTypeConverter typeConverter(&getContext());
+    typeConverter.addConversion([&](VectorType type) -> Type {
+      unsigned rank = type.getRank();
+      auto elemType = type.getElementType();
+      // If the element type is index, convert it to i64.
+      if (llvm::isa<IndexType>(elemType))
+        elemType = IntegerType::get(&getContext(), 64);
+      // If the vector is 0-D or has a single element, return just the
+      // element type.
+      if (rank < 1 || type.getNumElements() == 1)
+        return elemType;
+      // Otherwise, flatten the vector to 1-D.
+      return VectorType::get(type.getNumElements(), elemType);
+    });
+    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      if (type.isScattered()) {
+        return IntegerType::get(&getContext(), 64);
+      }
+      auto i32Type = IntegerType::get(&getContext(), 32);
+      return VectorType::get(8, i32Type);
+    });
+
+    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (input.getType() == builder.getIntegerType(64, false)) {
+        Value cast =
+            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+                .getResult();
+        return arith::IndexCastOp::create(builder, loc, type, cast).getResult();
+      }
+      return {};
+    };
+
+    auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+        if (vecTy.getNumElements() == 1) {
+          // If the vector has a single element, extract that element.
+          Value cast =
+              vector::ExtractOp::create(builder, loc, input, 0).getResult();
+          if (vecTy.getElementType() == builder.getIndexType())
+            cast = arith::IndexCastOp::create(builder, loc, type, cast)
+                       .getResult();
+          return cast;
+        }
+      }
+      return {};
+    };
+    typeConverter.addSourceMaterialization(ui64MaterializationCast);
+    typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+    typeConverter.addTargetMaterialization(ui64MaterializationCast);
+    typeConverter.addTargetMaterialization(vector1DMaterializationCast);
----------------
adam-smnk wrote:

It could be nice to expose these by moving the custom type conversions into a separate `populate*TypeConversions` function. Then others can still use the conversion patterns in custom passes.
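
For illustration, a rough sketch of what that split could look like (the helper name `populateXeGPUToXeVMTypeConversions` is hypothetical, and this assumes the pass also exposes a `populateXeGPUToXeVMConversionPatterns` entry point):

```cpp
// Hypothetical helper, declared in XeGPUToXeVM.h next to the pattern
// populate function, so custom passes can reuse the type conversions.
void populateXeGPUToXeVMTypeConversions(LLVMTypeConverter &typeConverter) {
  typeConverter.addConversion([](xegpu::TensorDescType type) -> Type {
    MLIRContext *ctx = type.getContext();
    // Scattered descriptors lower to a per-lane i64 address.
    if (type.isScattered())
      return IntegerType::get(ctx, 64);
    // Nd descriptors lower to the vector<8xi32> payload.
    return VectorType::get(8, IntegerType::get(ctx, 32));
  });
  // The vector flattening conversion and the source/target materialization
  // casts would move here as well.
}
```

A custom pass could then do:

```cpp
LLVMTypeConverter typeConverter(&getContext());
populateXeGPUToXeVMTypeConversions(typeConverter);
populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
```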

https://github.com/llvm/llvm-project/pull/154556


More information about the Mlir-commits mailing list