[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)
Sang Ik Lee
llvmlistbot at llvm.org
Wed Aug 27 10:03:40 PDT 2025
================
@@ -0,0 +1,1018 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+// Field offsets within the 8xi32 nd tensor descriptor (i32 slots, presumably).
+enum class NdTdescOffset : uint32_t {
+ BasePtr = 0, // Base pointer (i64); the next field starts at offset 2
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ TensorOffsetW = 4, // Tensor offset W (i32)
+ TensorOffsetH = 5 // Tensor offset H (i32)
+};
+
+// Maps an XeGPU memory space onto the numeric value of the matching XeVM
+// address-space enumerator (Global -> GLOBAL, SLM -> SHARED).
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  switch (xeGpuMemspace) {
+  case xegpu::MemorySpace::Global:
+    return static_cast<int32_t>(xevm::AddrSpace::GLOBAL);
+  case xegpu::MemorySpace::SLM:
+    return static_cast<int32_t>(xevm::AddrSpace::SHARED);
+  }
+  // Both handled enumerators return above. Without this, an out-of-range
+  // value would fall off the end of a non-void function (undefined behavior)
+  // and compilers emit -Wreturn-type.
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+// Builds a flat (1-D) vector type with element type `toElemType`. Its length
+// is the total bit count of `currentVecType` divided by the bit width of
+// `toElemType`.
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+                                     Type toElemType) {
+  auto srcElemBits = currentVecType.getElementType().getIntOrFloatBitWidth();
+  auto dstElemBits = toElemType.getIntOrFloatBitWidth();
+  auto totalBits = currentVecType.getNumElements() * srcElemBits;
+  const int flatLength = totalBits / dstElemBits;
+  return VectorType::get(flatLength, toElemType);
+}
+
+// Translates a pair of optional XeGPU L1/L3 cache hints into the XeVM load
+// cache-control enumerator. A missing hint is treated as UNCACHED. An L1 hint
+// of READ_INVALIDATE maps to INVALIDATE_READ without consulting the L3 hint.
+// For the other supported L1 hints the L3 hint must be CACHED or UNCACHED.
+static xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  using Policy = xegpu::CachePolicy;
+  using Control = xevm::LoadCacheControl;
+  const Policy l1 = L1hint.has_value() ? *L1hint : Policy::UNCACHED;
+  const Policy l3 = L3hint.has_value() ? *L3hint : Policy::UNCACHED;
+  // Chooses between the L3-cached and L3-uncached variant of one L1 setting.
+  auto pickByL3 = [l3](Control whenCached, Control whenUncached) -> Control {
+    if (l3 == Policy::CACHED)
+      return whenCached;
+    if (l3 == Policy::UNCACHED)
+      return whenUncached;
+    llvm_unreachable("Unsupported cache control.");
+  };
+  switch (l1) {
+  case Policy::CACHED:
+    return pickByL3(Control::L1C_L2UC_L3C, Control::L1C_L2UC_L3UC);
+  case Policy::UNCACHED:
+    return pickByL3(Control::L1UC_L2UC_L3C, Control::L1UC_L2UC_L3UC);
+  case Policy::STREAMING:
+    return pickByL3(Control::L1S_L2UC_L3C, Control::L1S_L2UC_L3UC);
+  case Policy::READ_INVALIDATE:
+    return Control::INVALIDATE_READ;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+// Translates a pair of optional XeGPU L1/L3 cache hints into the XeVM store
+// cache-control enumerator. A missing hint is treated as UNCACHED. The L1
+// hint must be UNCACHED, STREAMING, WRITE_BACK or WRITE_THROUGH, and the L3
+// hint must be UNCACHED or WRITE_BACK.
+static xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  using Policy = xegpu::CachePolicy;
+  using Control = xevm::StoreCacheControl;
+  const Policy l1 = L1hint.has_value() ? *L1hint : Policy::UNCACHED;
+  const Policy l3 = L3hint.has_value() ? *L3hint : Policy::UNCACHED;
+  // Chooses between the L3-uncached and L3-write-back variant of one L1
+  // setting.
+  auto pickByL3 = [l3](Control whenUncached,
+                       Control whenWriteBack) -> Control {
+    if (l3 == Policy::UNCACHED)
+      return whenUncached;
+    if (l3 == Policy::WRITE_BACK)
+      return whenWriteBack;
+    llvm_unreachable("Unsupported cache control.");
+  };
+  switch (l1) {
+  case Policy::UNCACHED:
+    return pickByL3(Control::L1UC_L2UC_L3UC, Control::L1UC_L2UC_L3WB);
+  case Policy::STREAMING:
+    return pickByL3(Control::L1S_L2UC_L3UC, Control::L1S_L2UC_L3WB);
+  case Policy::WRITE_BACK:
+    return pickByL3(Control::L1WB_L2UC_L3UC, Control::L1WB_L2UC_L3WB);
+  case Policy::WRITE_THROUGH:
+    return pickByL3(Control::L1WT_L2UC_L3UC, Control::L1WT_L2UC_L3WB);
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
----------------
silee2 wrote:
That is intentional: adaptor.getSource() for a memref yields an LLVM struct type, which is not easy to use. I'll try changing the type converter to override the LLVM type converter's behavior so that it does not convert memref types.
https://github.com/llvm/llvm-project/pull/154556
More information about the Mlir-commits
mailing list