[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)
Adam Siemieniuk
llvmlistbot at llvm.org
Wed Aug 27 10:51:13 PDT 2025
================
@@ -0,0 +1,1019 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Offsets, in i32 elements, of the individual fields inside the 8xi32
+/// payload vector that encodes an nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+  /// Base pointer (i64); being 64 bits wide it occupies i32 slots 0 and 1.
+  BasePtr = 0,
+  /// Width of the base shape (i32).
+  BaseShapeW = 2,
+  /// Height of the base shape (i32).
+  BaseShapeH = 3,
+  /// Tensor offset along the W dimension (i32).
+  TensorOffsetW = 4,
+  /// Tensor offset along the H dimension (i32).
+  TensorOffsetH = 5,
+};
+
+/// Maps an XeGPU memory space onto the numeric address space value of the
+/// corresponding XeVM address space (Global -> GLOBAL, SLM -> SHARED).
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int32_t>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int32_t>(xevm::AddrSpace::SHARED);
+  }
+  // Every handled enumerator returns above. Without this marker, control
+  // could fall off the end of a non-void function (-Wreturn-type, and
+  // undefined behavior for an out-of-range enum value).
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+/// Returns a flat (1-D) vector type with element type `toElemType` whose
+/// total bit width equals that of `currentVecType`.
+///
+/// Both element types must be integer or float types. The total bit width
+/// of `currentVecType` must be a multiple of the bit width of `toElemType`.
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+                                     Type toElemType) {
+  Type elemType = currentVecType.getElementType();
+  const int64_t currentBitWidth = elemType.getIntOrFloatBitWidth();
+  const int64_t newBitWidth = toElemType.getIntOrFloatBitWidth();
+  // Keep the arithmetic in 64 bits: getNumElements() returns int64_t, and its
+  // product with the element bit width can exceed the range of `int`.
+  const int64_t totalBits = currentVecType.getNumElements() * currentBitWidth;
+  // A non-multiple would silently drop trailing bits in the division below.
+  assert(newBitWidth > 0 && totalBits % newBitWidth == 0 &&
+         "vector bit width must be a multiple of the new element bit width");
+  return VectorType::get(totalBits / newBitWidth, toElemType);
+}
+
+/// Translates the optional XeGPU L1/L3 cache hints of a load into the
+/// matching XeVM load cache-control value. A missing hint is treated as
+/// UNCACHED. Hint combinations with no XeVM equivalent are unreachable.
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  const xegpu::CachePolicy l1 = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  const xegpu::CachePolicy l3 = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+
+  // Read-invalidate does not depend on the L3 hint.
+  if (l1 == xegpu::CachePolicy::READ_INVALIDATE)
+    return xevm::LoadCacheControl::INVALIDATE_READ;
+
+  // All remaining supported L1 hints pair with a cached or uncached L3 hint.
+  if (l3 != xegpu::CachePolicy::CACHED && l3 != xegpu::CachePolicy::UNCACHED)
+    llvm_unreachable("Unsupported cache control.");
+  const bool l3Cached = l3 == xegpu::CachePolicy::CACHED;
+
+  switch (l1) {
+  case xegpu::CachePolicy::CACHED:
+    return l3Cached ? xevm::LoadCacheControl::L1C_L2UC_L3C
+                    : xevm::LoadCacheControl::L1C_L2UC_L3UC;
+  case xegpu::CachePolicy::UNCACHED:
+    return l3Cached ? xevm::LoadCacheControl::L1UC_L2UC_L3C
+                    : xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+  case xegpu::CachePolicy::STREAMING:
+    return l3Cached ? xevm::LoadCacheControl::L1S_L2UC_L3C
+                    : xevm::LoadCacheControl::L1S_L2UC_L3UC;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+/// Translates the optional XeGPU L1/L3 cache hints of a store into the
+/// matching XeVM store cache-control value. A missing hint is treated as
+/// UNCACHED. Hint combinations with no XeVM equivalent are unreachable.
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  const xegpu::CachePolicy l1 = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  const xegpu::CachePolicy l3 = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+
+  // Every supported L1 hint pairs with an uncached or write-back L3 hint.
+  if (l3 != xegpu::CachePolicy::UNCACHED &&
+      l3 != xegpu::CachePolicy::WRITE_BACK)
+    llvm_unreachable("Unsupported cache control.");
+  const bool l3WriteBack = l3 == xegpu::CachePolicy::WRITE_BACK;
+
+  switch (l1) {
+  case xegpu::CachePolicy::UNCACHED:
+    return l3WriteBack ? xevm::StoreCacheControl::L1UC_L2UC_L3WB
+                       : xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+  case xegpu::CachePolicy::STREAMING:
+    return l3WriteBack ? xevm::StoreCacheControl::L1S_L2UC_L3WB
+                       : xevm::StoreCacheControl::L1S_L2UC_L3UC;
+  case xegpu::CachePolicy::WRITE_BACK:
+    return l3WriteBack ? xevm::StoreCacheControl::L1WB_L2UC_L3WB
+                       : xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+  case xegpu::CachePolicy::WRITE_THROUGH:
+    return l3WriteBack ? xevm::StoreCacheControl::L1WT_L2UC_L3WB
+                       : xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
+ // Op is lowered to a code sequence that populates payload.
+ // Payload is a 8xi32 vector. Offset to individual fields are defined in
+ // NdTdescOffset enum.
+ Type payloadElemTy = rewriter.getI32Type();
+ VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ Type i64Ty = rewriter.getI64Type();
+ // 4xi64 view is used for inserting the base pointer.
+ VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ // Initialize payload to zero.
+ Value payload = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+ Value baseAddr;
+ Value baseShapeW;
+ Value baseShapeH;
+ Value offsetW;
+ Value offsetH;
+
+ // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ // Descriptor shape is expected to be 2D.
+ int64_t rank = mixedSizes.size();
+ if (rank != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+ auto sourceTy = source.getType();
+ auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+ // If source is a memref, we need to extract the aligned pointer as index.
+ // Pointer type is passed as i32 or i64 by type converter.
+ if (sourceMemrefTy) {
+ if (!sourceMemrefTy.hasStaticShape()) {
+ op.emitError() << "Expected static memref shape.";
----------------
adam-smnk wrote:
Match failure?
https://github.com/llvm/llvm-project/pull/154556
More information about the Mlir-commits
mailing list