[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)
Sang Ik Lee
llvmlistbot at llvm.org
Wed Aug 27 14:36:52 PDT 2025
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/154556
From 9428381f89fb83dd872e67819c8d2ac2c74150eb Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 14 Jul 2025 18:54:41 +0000
Subject: [PATCH 01/17] Add XeGPUToXeVM conversion pass and tests.
---
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 12 +
.../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h | 27 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 25 +
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 932 ++++++++++++++++++
.../XeGPUToXeVM/create_nd_tdesc.mlir | 48 +
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 17 +
mlir/test/Conversion/XeGPUToXeVM/fence.mlir | 15 +
.../Conversion/XeGPUToXeVM/loadstore_nd.mlir | 71 ++
.../XeGPUToXeVM/loadstoreprefetch.mlir | 357 +++++++
.../Conversion/XeGPUToXeVM/prefetch_nd.mlir | 40 +
.../Conversion/XeGPUToXeVM/update_offset.mlir | 25 +
13 files changed, 1571 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/fence.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 91b2ecf8922a3..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -82,6 +82,7 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2058aba7f9e37..323af3e97e2d4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
let dependentDialects = ["LLVM::LLVMDialect"];
}
+//===----------------------------------------------------------------------===//
+// XeGPUToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
+ let summary = "Convert XeGPU to XeVM dialect";
+ let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect",
+ "memref::MemRefDialect", "arith::ArithDialect",
+ "LLVM::LLVMDialect", "index::IndexDialect",
+ "gpu::GPUDialect", "scf::SCFDialect"];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
new file mode 100644
index 0000000000000..fb23d24b0161b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -0,0 +1,27 @@
+//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect ----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeGPUToXeVMConversionPatterns(
+ mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 171f7169fd41d..134fe8e14ca38 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
add_subdirectory(XeVMToLLVM)
+add_subdirectory(XeGPUToXeVM)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..ed54b0bb5ee81
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+ XeGPUToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRVectorDialect
+ MLIRArithDialect
+ MLIRIndexDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000000000..380409afbc62e
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,932 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
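+// Offsets into the 8xi32 payload vector that encodes a 2D block tensor
+// descriptor: lanes 0-1 hold the 64-bit base pointer, lanes 2-3 the base
+// shape (width, height), and lanes 4-5 the tensor offsets (width, height).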
+enum class NdDescI32Layout : uint32_t {
+ BasePtr = 0,
+ BaseShapeW = 2,
+ BaseShapeH = 3,
+ TensorOffsetW = 4,
+ TensorOffsetH = 5
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+ switch (xeGpuMemspace) {
+ case xegpu::MemorySpace::Global:
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
+ case xegpu::MemorySpace::SLM:
+ return static_cast<int>(xevm::AddrSpace::SHARED);
+ }
+ llvm_unreachable("Unknown XeGPU memory space.");
+}
+
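+// Check whether the given values form an arithmetic (linear) sequence.
+// Returns {isLinear, slope, intercept}, where the intercept is the first
+// element of the sequence.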
+template <typename T>
+std::tuple<bool, int32_t, int32_t> checkAllLinear(SmallVector<T> denseAttr) {
+ assert(!denseAttr.empty());
+ const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
+ if (denseAttr.size() < 2)
+ return {true, 0, intercept};
+ const T slope{denseAttr[1] - denseAttr[0]};
+ for (size_t i = 1; i < denseAttr.size(); ++i)
+ if (denseAttr[i] - denseAttr[i - 1] != slope)
+ return {false, 0, 0};
+ return {true, static_cast<int32_t>(slope), intercept};
+}
+
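+// Re-type a vector to a new element type while preserving the total bit
+// width, e.g. vector<8xf32> becomes vector<16xf16>.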
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+ auto elemType = currentVecType.getElementType();
+ auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+ auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+ const int size =
+ currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+ return VectorType::get(size, toElemType);
+}
+
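+// Map optional XeGPU L1/L3 cache hints to an XeVM load cache control.
+// Missing hints default to UNCACHED; L2 is always uncached in this mapping.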
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal =
+ L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L3hintVal =
+ L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::CACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::READ_INVALIDATE:
+ return xevm::LoadCacheControl::INVALIDATE_READ;
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal =
+ L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L3hintVal =
+ L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_BACK:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_THROUGH:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
+ Type payloadElemTy = rewriter.getI32Type();
+ Type i64Ty = rewriter.getI64Type();
+ VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ Value payload = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+ Value baseAddr;
+ Value baseShapeW;
+ Value baseShapeH;
+ Value offsetW;
+ Value offsetH;
+
+ bool sourceIsMemref = false;
+ auto sourceTy = source.getType();
+ int64_t rank;
+ if (isa<MemRefType>(sourceTy)) {
+ sourceIsMemref = true;
+ baseAddr =
+ memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+ if (!sourceMemrefTy.hasStaticShape()) {
+ op.emitError() << "Expected static memref shape.";
+ return failure();
+ }
+ rank = sourceMemrefTy.getRank();
+ if (rank != 2) {
+ op.emitError() << "Expected a 2D memref.";
+ return failure();
+ }
+ } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+ rank = op.getMixedSizes().size();
+ } else {
+ op.emitError() << "Expected source to be a 2D memref or ui64.";
+ return failure();
+ }
+ auto createOffset = [&](unsigned idx) -> Value {
+ Value val;
+ OpFoldResult ofr = op.getMixedOffsets()[idx];
+ if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+ val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+ val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+ } else {
+ int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+ val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+ }
+ return val;
+ };
+ auto offsets = op.getMixedOffsets();
+ if (offsets.size() == 2) {
+ offsetW = createOffset(rank - 1);
+ offsetH = createOffset(rank - 2);
+ } else {
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ }
+ auto createShape = [&](unsigned idx) -> Value {
+ Value val;
+ OpFoldResult ofr = op.getMixedSizes()[idx];
+ if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+ val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+ val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+ } else {
+ int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+ val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+ }
+ return val;
+ };
+ if (sourceIsMemref) {
+ auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+ baseShapeW = arith::ConstantIntOp::create(
+ rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+ baseShapeH = arith::ConstantIntOp::create(
+ rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+ baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ } else {
+ baseShapeW = createShape(rank - 1);
+ baseShapeH = createShape(rank - 2);
+ baseAddr = adaptor.getSource();
+ }
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+ payLoadAsI64 =
+ vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+ static_cast<int>(NdDescI32Layout::BasePtr));
+ payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+ static_cast<int>(NdDescI32Layout::BaseShapeW));
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+ static_cast<int>(NdDescI32Layout::BaseShapeH));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetW, payload,
+ static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetH, payload,
+ static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+class UpdateNdOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+ xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto offsets = op.getOffsets();
+ auto tdesc = adaptor.getTensorDesc();
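+    // Add each non-zero offset into the corresponding offset lane of the
+    // descriptor payload; zero constant offsets are skipped.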
+ for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+ auto offset = offsets[offsetDim];
+ if (auto cst =
+ dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+ if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+ attr && !attr.getInt())
+ continue;
+ const int offsetPos =
+ static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+ : NdDescI32Layout::TensorOffsetH);
+ auto oldOffset =
+ vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+ offset = arith::IndexCastUIOp::create(rewriter, loc,
+ rewriter.getI32Type(), offset);
+ auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+ tdesc =
+ vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+ }
+ rewriter.replaceOp(op, tdesc);
+ return success();
+ }
+};
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+
+ auto tdesc = adaptor.getTensorDesc();
+ auto tdescTy = op.getTensorDescType();
+ if (tdescTy.getRank() != 2) {
+ return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ }
+
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdDescI32Layout::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+ // Offsets can come from three sources:
+ // 1. Constant offsets, which are provided by the op.
+ // 2. Offsets as operands, which are provided by the op.
+ // 3. Offsets extracted from the tensor descriptor.
+ Value offsetW;
+ Value offsetH;
+ auto cOffsets = op.getConstOffsets();
+ auto offsets = op.getOffsets();
+ if (cOffsets) {
+ offsetW = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
+ offsetH = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
+ } else if (offsets.size() != 0) {
+ // offsets are provided as operands
+ if (offsets[0].getType() != rewriter.getI32Type()) {
+ if (offsets[0].getType() != rewriter.getIndexType()) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected offsets to be of type i32 or index.");
+ }
+ offsetW = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), offsets[0]);
+ } else {
+ offsetW = offsets[0];
+ }
+ if (offsets[1].getType() != rewriter.getI32Type()) {
+ if (offsets[1].getType() != rewriter.getIndexType()) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected offsets to be of type i32 or index.");
+ }
+ offsetH = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), offsets[1]);
+ } else {
+ offsetH = offsets[1];
+ }
+ } else {
+ // If offsets are not available, we need to extract them from the tensor
+ // descriptor.
+ offsetW = vector::ExtractOp::create(
+ rewriter, loc, tdesc,
+ static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ offsetH = vector::ExtractOp::create(
+ rewriter, loc, tdesc,
+ static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ }
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ auto elemType = tdescTy.getElementType();
+ auto elemBitSize = elemType.getIntOrFloatBitWidth();
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ auto tileW = tdescTy.getDimSize(1);
+ auto tileH = tdescTy.getDimSize(0);
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType srcFlatVecTy =
+ VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
+ Value srcFlatVec = op.getValue();
+ srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
+ rewriter.getIntegerType(elemBitSize));
+ srcFlatVec =
+ vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+ offsetH, elemBitSize, tileW, tileH, vblocks,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+ surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ return success();
+ }
+};
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
+ xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
+int64_t getElemByteSize(OpType op) {
+ // Get the element byte size from the tensor descriptor.
+ auto elemBitWidth =
+ op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
+ return elemBitWidth / 8;
+}
+
+// Helper that computes a byte address:
+//   newAddr = baseAddr + offset * elemByteSize
+auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
+ Value baseAddr, Value offset,
+ int64_t elemByteSize) -> Value {
+ Value byteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
+ Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
+ return newAddr;
+};
+
+class CreateDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto offsets = adaptor.getOffsets();
+    // The source can be a 1D memref or a ui64 address.
+    // Use "op" instead of "adaptor" here because we need the original memref
+    // type rather than the converted LLVM struct type.
+ auto memrefTy = dyn_cast<MemRefType>(op.getSource().getType());
+ Value subGroupAddr;
+ if (memrefTy) {
+ subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, op.getSource());
+ subGroupAddr = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), subGroupAddr);
+ } else {
+ subGroupAddr = adaptor.getSource();
+ }
+ auto laneAddr =
+ addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op));
+ rewriter.replaceOp(op, laneAddr);
+ return success();
+ }
+};
+
+class UpdateOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateOffsetOp op,
+ xegpu::UpdateOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ Value newOffsetForLane =
+ addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
+ getElemByteSize(op));
+ rewriter.replaceOp(op, newOffsetForLane);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ Value basePtrI64;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ basePtrI64 = adaptor.getSource();
+ } else {
+ basePtrI64 = adaptor.getDest();
+ }
+ Value offsets = adaptor.getOffsets();
+ Value mask = adaptor.getMask();
+ if (offsets) {
+ VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+ if (offsetsVecTy) {
+        // Offsets must be a scalar.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+ }
+ }
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ VectorType srcOrDstVecTy = op.getValueType();
+ VectorType srcOrDstFlatVecTy = VectorType::get(
+ srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+ Value maskForLane;
+ VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
+ if (maskVecTy) {
+ return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+ } else
+ maskForLane = mask;
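+    // Guard the lane's memory access with an scf.if on the lane mask: a load
+    // yields a zero vector in the else branch, a store is simply skipped.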
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy},
+ maskForLane, true, true);
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value loaded =
+ LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
+ loaded.getDefiningOp()->setAttr("cache_control",
+ xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ if (srcOrDstVecTy != srcOrDstFlatVecTy) {
+ loaded =
+ vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
+ }
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ // If mask is false, we yield a vector of zeros.
+ auto eTy = srcOrDstVecTy.getElementType();
+ loaded = arith::ConstantOp::create(
+ rewriter, loc,
+ eTy.isFloat()
+ ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0))
+ : DenseElementsAttr::get(srcOrDstVecTy,
+ IntegerAttr::get(eTy, 0)));
+ scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+ rewriter.replaceOp(op, ifOp.getResult(0));
+ } else {
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
+ auto body = ifOp.getBody();
+ rewriter.setInsertionPointToStart(body);
+ VectorType valTy = op.getValue().getType();
+ Value srcFlatVec = op.getValue();
+ if (valTy != srcOrDstFlatVecTy) {
+ srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
+ srcOrDstFlatVecTy, srcFlatVec);
+ }
+ auto storeOp = LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+ storeOp.getOperation()->setAttr(
+ "cache_control",
+ xevm::StoreCacheControlAttr::get(ctxt,
+ translateStoreXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
+class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdescTy = op.getTensorDescType();
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ Value basePtrI64 = adaptor.getSource();
+ Value offsets = adaptor.getOffsets();
+ if (offsets) {
+ VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
+ if (offsetsVecTy) {
+        // Offsets must be a scalar.
+ return rewriter.notifyMatchFailure(op,
+ "Expected offsets to be a scalar.");
+ } else {
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+ }
+ }
+ Value ptrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ xevm::PrefetchOp::create(
+ rewriter, loc, ptrLLVM,
+ xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
+ switch (op.getFenceScope()) {
+ case xegpu::FenceScope::Workgroup:
+ memScope = xevm::MemScope::WORKGROUP;
+ break;
+ case xegpu::FenceScope::GPU:
+ memScope = xevm::MemScope::DEVICE;
+ break;
+ }
+ xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
+ switch (op.getMemoryKind()) {
+ case xegpu::MemorySpace::Global:
+ addrSpace = xevm::AddrSpace::GLOBAL;
+ break;
+ case xegpu::MemorySpace::SLM:
+ addrSpace = xevm::AddrSpace::SHARED;
+ break;
+ }
+ xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto aTy = cast<VectorType>(op.getLhs().getType());
+ auto bTy = cast<VectorType>(op.getRhs().getType());
+ auto resultType = cast<VectorType>(op.getResultType());
+
+ auto encodePrecision = [&](Type type) -> xevm::ElemType {
+ if (type == rewriter.getBF16Type())
+ return xevm::ElemType::BF16;
+ else if (type == rewriter.getF16Type())
+ return xevm::ElemType::F16;
+ else if (type == rewriter.getTF32Type())
+ return xevm::ElemType::TF32;
+ else if (type.isInteger(8)) {
+ if (type.isUnsignedInteger())
+ return xevm::ElemType::U8;
+ return xevm::ElemType::S8;
+ } else if (type == rewriter.getF32Type())
+ return xevm::ElemType::F32;
+ else if (type.isInteger(32))
+ return xevm::ElemType::S32;
+ llvm_unreachable("add more support for ElemType");
+ };
+ xevm::ElemType precATy = encodePrecision(aTy.getElementType());
+ xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
+ Value c = op.getAcc();
+ if (!c) {
+ auto elementTy = resultType.getElementType();
+ Attribute initValueAttr;
+ if (isa<FloatType>(elementTy))
+ initValueAttr = FloatAttr::get(elementTy, 0.0);
+ else
+ initValueAttr = IntegerAttr::get(elementTy, 0);
+ c = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
+ }
+
+ Value aVec = op.getLhs();
+ Value bVec = op.getRhs();
+ auto cvecty = cast<VectorType>(c.getType());
+ xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
+ xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
+ VectorType cNty =
+ VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
+ if (cvecty != cNty)
+ c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
+    // The values below are uArch-dependent and should eventually come from a
+    // uArch description instead of being hardcoded.
+ constexpr int32_t systolicDepth{8};
+ constexpr int32_t executionSize{16};
+ Value dpasRes = xevm::MMAOp::create(
+ rewriter, loc, cNty, aVec, bVec, c,
+ xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
+ systolicDepth *
+ getNumOperandsPerDword(precATy)),
+ xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
+ if (cvecty != cNty)
+ dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
+ rewriter.replaceOp(op, dpasRes);
+ return success();
+ }
+
+private:
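+  // Number of A/B elements that are packed into a single 32-bit dword; the
+  // MMA K dimension is systolicDepth * getNumOperandsPerDword(precATy).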
+ static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+ switch (pTy) {
+ case xevm::ElemType::TF32:
+ return 1;
+ case xevm::ElemType::BF16:
+ case xevm::ElemType::F16:
+ return 2;
+ case xevm::ElemType::U8:
+ case xevm::ElemType::S8:
+ return 4;
+ default:
+ llvm_unreachable("unsupported xevm::ElemType");
+ }
+ }
+};
+
+static std::optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
+ switch (arithKind) {
+ case arith::AtomicRMWKind::addf:
+ return LLVM::AtomicBinOp::fadd;
+ case arith::AtomicRMWKind::addi:
+ return LLVM::AtomicBinOp::add;
+ case arith::AtomicRMWKind::assign:
+ return LLVM::AtomicBinOp::xchg;
+ case arith::AtomicRMWKind::maximumf:
+ return LLVM::AtomicBinOp::fmax;
+ case arith::AtomicRMWKind::maxs:
+ return LLVM::AtomicBinOp::max;
+ case arith::AtomicRMWKind::maxu:
+ return LLVM::AtomicBinOp::umax;
+ case arith::AtomicRMWKind::minimumf:
+ return LLVM::AtomicBinOp::fmin;
+ case arith::AtomicRMWKind::mins:
+ return LLVM::AtomicBinOp::min;
+ case arith::AtomicRMWKind::minu:
+ return LLVM::AtomicBinOp::umin;
+ case arith::AtomicRMWKind::ori:
+ return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::andi:
+ return LLVM::AtomicBinOp::_and;
+ default:
+ return std::nullopt;
+ }
+ llvm_unreachable("Invalid AtomicRMWKind");
+}
+
+class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ auto tdesc = op.getTensorDesc().getType();
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
+ Value basePtrI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
+ VectorType srcOrDstFlatVecTy = VectorType::get(
+ srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+ Value srcFlatVec = vector::ShapeCastOp::create(
+ rewriter, loc, srcOrDstFlatVecTy, op.getValue());
+ auto atomicKind = matchSimpleAtomicOp(op.getKind());
+ assert(atomicKind.has_value());
+ Value resVec = srcFlatVec;
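+    // Lower the vector atomic RMW to one scalar llvm.atomicrmw per element
+    // and reassemble the results into a vector.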
+ for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
+ auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
+ Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+ rewriter.getIndexAttr(i));
+ Value currPtr =
+ LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
+ srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
+ Value newVal =
+ LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
+ val, LLVM::AtomicOrdering::seq_cst);
+ resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
+ }
+ rewriter.replaceOp(op, resVec);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeGPUToXeVMPass
+ : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
+ using Base::Base;
+
+ void runOnOperation() override {
+ LLVMTypeConverter typeConverter(&getContext());
+ typeConverter.addConversion([&](VectorType type) -> Type {
+ unsigned rank = type.getRank();
+ auto elemType = type.getElementType();
+ // If the element type is index, convert it to i64.
+ if (llvm::isa<IndexType>(elemType))
+ elemType = IntegerType::get(&getContext(), 64);
+      // If the vector is rank-0 or has a single element, return just the
+      // element type.
+      if (rank < 1 || type.getNumElements() == 1)
+        return elemType;
+      // Otherwise, flatten the vector to a 1-D vector type.
+      unsigned numElems = 1;
+      for (unsigned i = 0; i < rank; i++)
+        numElems *= type.getShape()[i];
+      return VectorType::get(numElems, elemType);
+ });
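+    // Tensor descriptors: scattered descriptors lower to a plain i64 lane
+    // address, block (nd) descriptors lower to the 8xi32 payload vector.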
+ typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ if (type.isScattered()) {
+ return IntegerType::get(&getContext(), 64);
+ }
+ auto i32Type = IntegerType::get(&getContext(), 32);
+ return VectorType::get(8, i32Type);
+ });
+
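+    // Materialization casts bridge unrealized conversions: ui64 sources are
+    // cast through index to the target integer type, and single-element
+    // vectors are collapsed to their scalar element.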
+ auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(64, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastOp::create(builder, loc, type, cast).getResult();
+ }
+ return {};
+ };
+
+ auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
+ if (vecTy.getNumElements() == 1) {
+          // A single-element vector: extract and return its element.
+ Value cast =
+ vector::ExtractOp::create(builder, loc, input, 0).getResult();
+ if (vecTy.getElementType() == builder.getIndexType())
+ cast = arith::IndexCastOp::create(builder, loc, type, cast)
+ .getResult();
+ return cast;
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(ui64MaterializationCast);
+ typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+ typeConverter.addTargetMaterialization(ui64MaterializationCast);
+ typeConverter.addTargetMaterialization(vector1DMaterializationCast);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
+ vector::VectorDialect, arith::ArithDialect,
+ memref::MemRefDialect, gpu::GPUDialect,
+ index::IndexDialect>();
+ target.addIllegalDialect<xegpu::XeGPUDialect>();
+
+ RewritePatternSet patterns(&getContext());
+ populateXeGPUToXeVMConversionPatterns(patterns, typeConverter);
+ scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+void mlir::populateXeGPUToXeVMConversionPatterns(
+ RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) {
+ patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
+ LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
+ AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+ LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
+ LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
+ typeConverter, patterns.getContext());
+ patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
+ patterns.getContext());
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
new file mode 100644
index 0000000000000..4fba920f023c4
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @create_nd_tdesc {
+ // CHECK-LABEL: gpu.func @create_nd_tdesc
+ // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64
+ // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
+ gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ %stride1: index, %stride2: index) kernel {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
+ // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
+ // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32
+ // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64
+ // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32
+ // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR9:.*]] = vector.insert %[[VAR3]], %[[VAR8]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR10:.*]] = vector.insert %[[VAR5]], %[[VAR9]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[C0_I32]], %[[VAR10]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR12:.*]] = vector.insert %[[C0_I32_0]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
+ : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR17:.*]] = vector.insert %[[C16_I32]], %[[VAR16]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR18:.*]] = vector.insert %[[C8_I32]], %[[VAR17]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[C0_I32_2]], %[[VAR18]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR20:.*]] = vector.insert %[[C0_I32_3]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
new file mode 100644
index 0000000000000..15940fc4aca26
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+
+gpu.module @dpas_check {
+ //CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>) -> vector<8xf32>
+ func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
+ // Loads are checked in a separate test.
+ // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
+ // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+ %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+ : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+ return %d : vector<8xf32>
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/fence.mlir b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
new file mode 100644
index 0000000000000..cedfcace398a6
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @fence_check {
+ gpu.func @fence(%dst: memref<8x16xf32, 1>) kernel {
+ %tid_x = gpu.thread_id x
+ %tid_x_i32 = arith.index_cast %tid_x : index to i32
+ %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+
+ // CHECK: xevm.memfence <{addrspace = #xevm.addr_space<global>, scope = #xevm.mem_scope<workgroup>}>
+ xegpu.fence memory_kind = global, fence_scope = workgroup
+ %c0 = arith.constant 0 : index
+ memref.store %tid_x_f32, %dst[%c0, %c0] : memref<8x16xf32, 1>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
new file mode 100644
index 0000000000000..c692da632d458
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @load_store_check {
+ gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+
+ //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32
+ //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
+ //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+ //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
+
+ %tid_x = gpu.thread_id x
+ %tid_x_i32 = arith.index_cast %tid_x : index to i32
+ %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+ //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+ %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+ // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+ //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32
+ //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
+ //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
+ //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
+ //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
+ //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
new file mode 100644
index 0000000000000..f6d023307313a
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -0,0 +1,357 @@
+// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+ %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
+ // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32>
+ // CHECK: scf.yield %[[VAR9]] : vector<2xf32>
+ // CHECK: } else {
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+ // CHECK: scf.yield %[[CST_1]] : vector<2xf32>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32>
+ // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
+ // CHECK: scf.yield %[[VAR9]] : f32
+ // CHECK: } else {
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+  // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1]][0] : f32 from vector<1xf32>
+ // CHECK: scf.yield %[[VAR8]] : f32
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) {
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16>
+ // CHECK: scf.yield %[[VAR8]] : vector<8xf16>
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
+ // CHECK: scf.yield %[[CST_0]] : vector<8xf16>
+ %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_memref_src_load_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+ // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
+ %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64
+ // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
+ // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) {
+ // CHECK: %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+ // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16>
+ // CHECK: scf.yield %[[VAR12]] : vector<8xf16>
+ // CHECK: } else {
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
+ // CHECK: scf.yield %[[CST_0]] : vector<8xf16>
+ %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+ %2 = arith.constant dense<2.9>: vector<2xf32>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+ %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR4]] {
+ // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
+ %2 = arith.constant dense<2.9>: vector<2xf16>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR2]] {
+ // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+ %2 = arith.constant dense<2.9>: vector<1xf32>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+ %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR2]] {
+ // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1>
+ xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @store_scatter_memref_src_store_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+ %1 = arith.constant dense<1>: vector<1xi1>
+ // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+ %2 = arith.constant dense<2.9>: vector<1xf32>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
+ %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64
+ // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
+ // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
+ // CHECK: scf.if %[[VAR4]] {
+ // CHECK: llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1>
+ xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ %0 = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ gpu.return
+}
+}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @prefetch_memref_src_prefetch_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
+ %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+ // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
+ // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64
+ // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+ xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xindex>
+ gpu.return
+}
+}
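For reference, the scatter-path tests above all check the same lowered shape: the lane offset (in elements) is scaled by the element byte size and added to the flat base address, and the access itself is wrapped in an scf.if on the lane mask. A minimal sketch of that form, where %base, %off and %mask are illustrative names (not from the patch) for the already-converted i64 base address, i64 offset and i1 mask:

  %c4 = arith.constant 4 : i64                  // element byte size (f32)
  %byteoff = arith.muli %off, %c4 : i64         // element offset -> byte offset
  %addr = arith.addi %base, %byteoff : i64      // flat per-lane address
  %ptr = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
  %val = scf.if %mask -> (vector<2xf32>) {
    %v = llvm.load %ptr {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<2xf32>
    scf.yield %v : vector<2xf32>
  } else {
    %zero = arith.constant dense<0.0> : vector<2xf32>
    scf.yield %zero : vector<2xf32>
  }

Stores follow the same addressing but emit llvm.store with a #xevm.store_cache_control attribute inside a result-less scf.if, and prefetches feed the pointer to xevm.prefetch directly.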
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
new file mode 100644
index 0000000000000..8513b4f9857fb
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+
+gpu.module @fence_check {
+ gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+ // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
+ #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+ //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+ //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+ //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+ //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+ //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32
+ //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
+ //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
+ //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
+ //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
+ //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+ xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>,
+ #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+ gpu.return
+ }
+}
+
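Stripped of the FileCheck captures, the block prefetch above reduces to a single xevm.blockprefetch2d whose row pitch is the descriptor's base width scaled to bytes. A sketch, where %ptr, %basew, %baseh, %x and %y are illustrative names for the values unpacked from the descriptor:

  %sizeof = arith.constant 4 : i32             // sizeof(f32)
  %pitch = arith.muli %basew, %sizeof : i32    // base row width in bytes
  xevm.blockprefetch2d %ptr, %pitch, %baseh, %pitch, %x, %y
      <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
        tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
      : (!llvm.ptr<1>, i32, i32, i32, i32, i32)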
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
new file mode 100644
index 0000000000000..e9d7fd4cf40a6
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @update_offset {
+ // CHECK-LABEL: gpu.func @update_offset
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @update_offset(%src: memref<128xf32>) kernel {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ %offset = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+ // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
+ // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
+ // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
+ %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ , vector<1xindex>
+ gpu.return
+ }
+}
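After conversion the scatter tensor descriptor is just the per-lane address held in an i64, so xegpu.update_offset lowers to the same scale-and-add used when the descriptor was created. A sketch, assuming %tdesc is the converted i64 descriptor and %off the converted i64 offset (illustrative names):

  %c4 = arith.constant 4 : i64            // element byte size (f32)
  %delta = arith.muli %off, %c4 : i64     // offset in elements -> bytes
  %newdesc = arith.addi %tdesc, %delta : i64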
>From 61bee9f1feb386f7e8402424c8b2afba75b16062 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 15:25:41 +0000
Subject: [PATCH 02/17] Apply clang format.
---
.../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 380409afbc62e..32983152ef5bd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -558,10 +558,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value loaded =
LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
- loaded.getDefiningOp()->setAttr("cache_control",
- xevm::LoadCacheControlAttr::get(
- ctxt, translateLoadXeGPUCacheHint(
- op.getL1Hint(), op.getL3Hint())));
+ loaded.getDefiningOp()->setAttr(
+ "cache_control", xevm::LoadCacheControlAttr::get(
+ ctxt, translateLoadXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
if (srcOrDstVecTy != srcOrDstFlatVecTy) {
loaded =
vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
@@ -588,12 +588,12 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
srcOrDstFlatVecTy, srcFlatVec);
}
- auto storeOp = LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+ auto storeOp =
+ LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
storeOp.getOperation()->setAttr(
- "cache_control",
- xevm::StoreCacheControlAttr::get(ctxt,
- translateStoreXeGPUCacheHint(
- op.getL1Hint(), op.getL3Hint())));
+ "cache_control", xevm::StoreCacheControlAttr::get(
+ ctxt, translateStoreXeGPUCacheHint(
+ op.getL1Hint(), op.getL3Hint())));
rewriter.eraseOp(op);
}
return success();
>From 4aa4cb29accaa3242cbcdb74155ed0a4eb6ec536 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 15:29:01 +0000
Subject: [PATCH 03/17] Remove commented out code.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 32983152ef5bd..89f40c22e7a68 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -380,8 +380,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
- // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(),
- // elemBitSize);
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
@@ -695,8 +693,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
};
xevm::ElemType precATy = encodePrecision(aTy.getElementType());
xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
- // auto precA = xevm::ElemTypeAttr::get(ctxt, precATy);
- // auto precB = xevm::ElemTypeAttr::get(ctxt, precBTy);
Value c = op.getAcc();
if (!c) {
auto elementTy = resultType.getElementType();
@@ -714,8 +710,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto cvecty = cast<VectorType>(c.getType());
xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
- // auto precC = xevm::ElemTypeAttr::get(ctxt, precCTy);
- // auto precD = xevm::ElemTypeAttr::get(ctxt, precDTy);
VectorType cNty =
VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
if (cvecty != cNty)
>From 687e831902974da829d692d415df14b0e50a25b2 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 17:59:49 +0000
Subject: [PATCH 04/17] Remove dead code.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 -------------
1 file changed, 13 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 89f40c22e7a68..776380974c549 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -56,19 +56,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
llvm_unreachable("Unknown XeGPU memory space.");
}
-template <typename T>
-std::tuple<bool, int32_t, int32_t> checkAllLinear(SmallVector<T> denseAttr) {
- assert(!denseAttr.empty());
- const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
- if (denseAttr.size() < 2)
- return {true, 0, intercept};
- const T slope{denseAttr[1] - denseAttr[0]};
- for (size_t i = 1; i < denseAttr.size(); ++i)
- if (denseAttr[i] - denseAttr[i - 1] != slope)
- return {false, 0, 0};
- return {true, static_cast<int32_t>(slope), intercept};
-}
-
VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
auto elemType = currentVecType.getElementType();
auto currentBitWidth = elemType.getIntOrFloatBitWidth();
>From e240e47a1c3731ba2beb8af5e5e8f08f858a84c1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 21 Aug 2025 22:41:37 +0000
Subject: [PATCH 05/17] Temp save.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 128 +++++++++++++-----
.../XeGPUToXeVM/create_nd_tdesc.mlir | 4 +-
.../XeGPUToXeVM/loadstoreprefetch.mlir | 4 +-
.../XeGPUToXeVM/materializecast.mlir | 49 +++++++
.../Conversion/XeGPUToXeVM/update_offset.mlir | 6 +-
5 files changed, 150 insertions(+), 41 deletions(-)
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 776380974c549..4ff5321c1d9d2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -426,18 +427,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
}
};
-template <
- typename OpType,
- typename = std::enable_if_t<llvm::is_one_of<
- OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
- xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
-int64_t getElemByteSize(OpType op) {
- // Get the element byte size from the tensor descriptor.
- auto elemBitWidth =
- op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
- return elemBitWidth / 8;
-}
-
// Add a builder that creates
// offset * elemByteSize + baseAddr
auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
@@ -456,23 +445,23 @@ class CreateDescToXeVMPattern
LogicalResult
matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+ return rewriter.notifyMatchFailure(op,
+ "Expected element type bit width to be multiple of 8.");
+ }
auto loc = op.getLoc();
+ // offsets are provided as scalar i64 by type converter.
auto offsets = adaptor.getOffsets();
- // Source type can be a 1D memref or ui64
- // Using "op" instead of "adaptor" since we want to access memref type
- // instead of LLVM struct type.
- auto memrefTy = dyn_cast<MemRefType>(op.getSource().getType());
- Value subGroupAddr;
- if (memrefTy) {
- subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, op.getSource());
- subGroupAddr = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI64Type(), subGroupAddr);
- } else {
- subGroupAddr = adaptor.getSource();
- }
+ // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+ // But type converter will convert them to integer types.
+ Value addr = adaptor.getSource();
+ // ui32 or i32 are passed as i32 so they need to be cast to i64.
+ if (addr.getType() != rewriter.getI64Type())
+ addr = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), addr);
auto laneAddr =
- addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op));
+ addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
rewriter.replaceOp(op, laneAddr);
return success();
}
@@ -485,11 +474,18 @@ class UpdateOffsetToXeVMPattern
matchAndRewrite(xegpu::UpdateOffsetOp op,
xegpu::UpdateOffsetOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto eTy = op.getTensorDescType().getElementType();
+ if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+ return rewriter.notifyMatchFailure(op,
+ "Expected element type bit width to be multiple of 8.");
+ }
auto loc = op.getLoc();
- Value newOffsetForLane =
+ // scatter descriptor is provided as scalar i64 by type converter.
+ // offsets are provided as scalar i64 by type converter.
+ Value newOffset =
addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
getElemByteSize(op));
- rewriter.replaceOp(op, newOffsetForLane);
+ rewriter.replaceOp(op, newOffset);
return success();
}
};
@@ -505,19 +501,38 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
- auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
- ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ if (tdescTy)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
Value basePtrI64;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
basePtrI64 = adaptor.getSource();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
} else {
basePtrI64 = adaptor.getDest();
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
}
+ if (basePtrI64.getType() != rewriter.getI64Type()) {
+ basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ }
+ basePtrI64.dump();
Value offsets = adaptor.getOffsets();
+ offsets.dump();
Value mask = adaptor.getMask();
+ mask.dump();
if (offsets) {
- VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
- if (offsetsVecTy) {
+ if (dyn_cast<VectorType>(offsets.getType())){
// Offset needs be scalar.
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
@@ -526,8 +541,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
}
}
+ basePtrI64.dump();
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ basePtrLLVM.dump();
VectorType srcOrDstVecTy = op.getValueType();
VectorType srcOrDstFlatVecTy = VectorType::get(
srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
@@ -597,6 +614,10 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
Value basePtrI64 = adaptor.getSource();
Value offsets = adaptor.getOffsets();
+ if (basePtrI64.getType() != rewriter.getI64Type()) {
+ basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
+ }
if (offsets) {
VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
if (offsetsVecTy) {
@@ -836,6 +857,26 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ typeConverter.addConversion([&](MemRefType type) -> Type {
+ // Convert MemRefType to i64 type.
+ return IntegerType::get(&getContext(), 64);
+ });
+
+ auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+
+ Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
+ builder, loc, input);
+ return arith::IndexCastUIOp::create(builder, loc, type,
+ addr).getResult();
+ }
+ return {};
+ };
auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
@@ -847,7 +888,22 @@ struct ConvertXeGPUToXeVMPass
Value cast =
index::CastUOp::create(builder, loc, builder.getIndexType(), input)
.getResult();
- return arith::IndexCastOp::create(builder, loc, type, cast).getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+ }
+ return {};
+ };
+
+ auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType() == builder.getIntegerType(32, false)) {
+ Value cast =
+ index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+ .getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
}
return {};
};
@@ -864,15 +920,19 @@ struct ConvertXeGPUToXeVMPass
Value cast =
vector::ExtractOp::create(builder, loc, input, 0).getResult();
if (vecTy.getElementType() == builder.getIndexType())
- cast = arith::IndexCastOp::create(builder, loc, type, cast)
+ cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
}
}
return {};
};
+ typeConverter.addSourceMaterialization(memrefMaterializationCast);
typeConverter.addSourceMaterialization(ui64MaterializationCast);
+ typeConverter.addSourceMaterialization(ui32MaterializationCast);
typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+ typeConverter.addTargetMaterialization(memrefMaterializationCast);
+ typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
typeConverter.addTargetMaterialization(vector1DMaterializationCast);
ConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4fba920f023c4..7f5e3527a1594 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -6,8 +6,8 @@ gpu.module @create_nd_tdesc {
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index) kernel {
- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index f6d023307313a..825a4d6368863 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -5,10 +5,10 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:.*]]: ui64
gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
new file mode 100644
index 0000000000000..a7ae4d9b7e4d2
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @materializecast {
+ // CHECK-LABEL: gpu.func @materialize_memref
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
+ // CHECK: XXX
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+ // CHECK-LABEL: gpu.func @materialize_ui64
+ // CHECK-SAME: %[[ARG0:.*]]: ui64
+ gpu.func @materialize_ui64(%src: ui64) kernel {
+ // CHECK: XXX
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+ // CHECK-LABEL: gpu.func @materialize_ui32
+ // CHECK-SAME: %[[ARG0:.*]]: ui32
+ gpu.func @materialize_ui32(%src: ui32) kernel {
+ %offset = arith.constant dense<0> : vector<1xindex>
+ //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+ // -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+ // CHECK-LABEL: gpu.func @materialize_single_index_vector
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+ gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
+ // CHECK: XXX
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.return
+ }
+ // CHECK-LABEL: gpu.func @materialize_single_elem_vector
+ // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1>
+ gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
+ // CHECK: XXX
+ %mask = arith.constant dense<1>: vector<1xi1>
+ %offset = arith.constant dense<0> : vector<1xindex>
+ %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
+ gpu.return
+ }
+}
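For reference, the materializations these functions are meant to exercise produce the shapes below; a sketch assuming %src : memref<128xf32> and %offset : vector<1xindex> as in the functions above:

  // memref source -> flat i64 base address
  %intptr = memref.extract_aligned_pointer_as_index %src : memref<128xf32> -> index
  %base = arith.index_castui %intptr : index to i64
  // single-element index vector -> scalar i64 offset
  %idx = vector.extract %offset[0] : index from vector<1xindex>
  %off = arith.index_castui %idx : index to i64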
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
index e9d7fd4cf40a6..6e59414c62582 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
@@ -4,12 +4,12 @@ gpu.module @update_offset {
// CHECK-LABEL: gpu.func @update_offset
// CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
gpu.func @update_offset(%src: memref<128xf32>) kernel {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
%offset = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
- // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
>From 8e507ec303c37524d793d0dcef382922a185257f Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 18:32:56 +0000
Subject: [PATCH 06/17] Adjust to latest XeGPU dialect update.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 198 ++++++++++++------
.../XeGPUToXeVM/loadstoreprefetch.mlir | 142 ++-----------
.../XeGPUToXeVM/materializecast.mlir | 39 +++-
3 files changed, 185 insertions(+), 194 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 4ff5321c1d9d2..6cfa8ac1f8fce 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -446,9 +446,10 @@ class CreateDescToXeVMPattern
matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto eTy = op.getTensorDescType().getElementType();
- if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
- return rewriter.notifyMatchFailure(op,
- "Expected element type bit width to be multiple of 8.");
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
}
auto loc = op.getLoc();
// offsets are provided as scalar i64 by type converter.
@@ -458,10 +459,8 @@ class CreateDescToXeVMPattern
Value addr = adaptor.getSource();
// ui32 or i32 are passed as i32 so they need to be cast to i64.
if (addr.getType() != rewriter.getI64Type())
- addr = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI64Type(), addr);
- auto laneAddr =
- addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
+ addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+ auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
rewriter.replaceOp(op, laneAddr);
return success();
}
@@ -475,16 +474,16 @@ class UpdateOffsetToXeVMPattern
xegpu::UpdateOffsetOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto eTy = op.getTensorDescType().getElementType();
- if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
- return rewriter.notifyMatchFailure(op,
- "Expected element type bit width to be multiple of 8.");
+ auto eBw = eTy.getIntOrFloatBitWidth();
+ if (eBw % 8 != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
}
auto loc = op.getLoc();
// scatter descriptor is provided as scalar i64 by type converter.
// offsets are provided as scalar i64 by type converter.
- Value newOffset =
- addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
- getElemByteSize(op));
+ Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+ adaptor.getOffsets(), eBw / 8);
rewriter.replaceOp(op, newOffset);
return success();
}
@@ -501,12 +500,35 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
- LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
- ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
- if (tdescTy)
- ptrTypeLLVM = LLVM::LLVMPointerType::get(
- ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
Value basePtrI64;
+ // Load result or store value type can be vector or scalar.
+ Type valOrResTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ valOrResTy = op.getResult().getType();
+ } else {
+ valOrResTy = adaptor.getValue().getType();
+ }
+ VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+ bool hasScalarVal = !valOrResVecTy;
+ int64_t elemBitWidth =
+ hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+ : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ }
+ int64_t elemByteSize = elemBitWidth / 8;
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy) {
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ }
+ // Base pointer can come from source (load) or dest (store).
+ // If they are memrefs, we use their memory space.
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
basePtrI64 = adaptor.getSource();
if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
@@ -522,76 +544,79 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
}
}
+ // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
if (basePtrI64.getType() != rewriter.getI64Type()) {
- basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
- basePtrI64);
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
}
- basePtrI64.dump();
Value offsets = adaptor.getOffsets();
- offsets.dump();
Value mask = adaptor.getMask();
- mask.dump();
if (offsets) {
- if (dyn_cast<VectorType>(offsets.getType())){
- // Offset needs be scalar.
+ if (dyn_cast<VectorType>(offsets.getType())) {
+ // Offset needs to be a scalar. A single-element vector is converted to a
+ // scalar by the type converter.
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
} else {
+ // If offsets are provided, we add them to the base pointer.
+ // Offsets are in number of elements, we need to multiply by
+ // element byte size.
basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
}
}
- basePtrI64.dump();
+ // Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
- basePtrLLVM.dump();
- VectorType srcOrDstVecTy = op.getValueType();
- VectorType srcOrDstFlatVecTy = VectorType::get(
- srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+
Value maskForLane;
VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
if (maskVecTy) {
+ // Mask needs to be a scalar. A single-element vector is converted to a
+ // scalar by the type converter.
return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
- } else
+ } else {
maskForLane = mask;
+ }
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
- scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy},
+ scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
maskForLane, true, true);
+ // If mask is true (then clause), load from memory and yield.
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ if (!hasScalarVal)
+ valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+ valOrResVecTy.getElementType());
Value loaded =
- LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
+ LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+ // Set cache control attribute on the load operation.
loaded.getDefiningOp()->setAttr(
"cache_control", xevm::LoadCacheControlAttr::get(
ctxt, translateLoadXeGPUCacheHint(
op.getL1Hint(), op.getL3Hint())));
- if (srcOrDstVecTy != srcOrDstFlatVecTy) {
- loaded =
- vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
- }
scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
- // If mask is false, we yield a vector of zeros.
- auto eTy = srcOrDstVecTy.getElementType();
- loaded = arith::ConstantOp::create(
- rewriter, loc,
- eTy.isFloat()
- ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0))
- : DenseElementsAttr::get(srcOrDstVecTy,
- IntegerAttr::get(eTy, 0)));
+ // If mask is false (else clause), yield a vector of zeros.
+ auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+ TypedAttr eVal;
+ if (eTy.isFloat())
+ eVal = FloatAttr::get(eTy, 0.0);
+ else
+ eVal = IntegerAttr::get(eTy, 0);
+ if (hasScalarVal)
+ loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+ else
+ loaded = arith::ConstantOp::create(
+ rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
rewriter.replaceOp(op, ifOp.getResult(0));
} else {
+ // If mask is true, perform the store.
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
auto body = ifOp.getBody();
rewriter.setInsertionPointToStart(body);
- VectorType valTy = op.getValue().getType();
- Value srcFlatVec = op.getValue();
- if (valTy != srcOrDstFlatVecTy) {
- srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
- srcOrDstFlatVecTy, srcFlatVec);
- }
auto storeOp =
- LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+ // Set cache control attribute on the store operation.
storeOp.getOperation()->setAttr(
"cache_control", xevm::StoreCacheControlAttr::get(
ctxt, translateStoreXeGPUCacheHint(
@@ -610,14 +635,13 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdescTy = op.getTensorDescType();
- auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
- ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
Value basePtrI64 = adaptor.getSource();
- Value offsets = adaptor.getOffsets();
+ // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
if (basePtrI64.getType() != rewriter.getI64Type()) {
- basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
- basePtrI64);
+ basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ basePtrI64);
}
+ Value offsets = adaptor.getOffsets();
if (offsets) {
VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
if (offsetsVecTy) {
@@ -625,12 +649,50 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
} else {
+ int64_t elemBitWidth{0};
+ int64_t elemByteSize;
+ // Element byte size can come from three sources:
+ if (tdescTy) {
+ // If tensor descriptor is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+ } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
+ // If memref is available, we use its element type to
+ // determine element byte size.
+ elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
+ } else {
+ // Otherwise, we use the provided offset byte alignment.
+ elemByteSize = *op.getOffsetAlignByte();
+ }
+ if (elemBitWidth != 0) {
+ if (elemBitWidth % 8 != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ }
+ elemByteSize = elemBitWidth / 8;
+ }
basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+ addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
}
}
+ // Default memory space is global.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+ // If tensor descriptor is available, we use its memory space.
+ if (tdescTy) {
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ }
+ // If source is a memref, we use its memory space.
+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+ auto addrSpace = memRefTy.getMemorySpaceAsInt();
+ if (addrSpace != 0)
+ ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+ }
+ // Convert base pointer (i64) to LLVM pointer type.
Value ptrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ // Create the prefetch op with cache control attribute.
xevm::PrefetchOp::create(
rewriter, loc, ptrLLVM,
xevm::LoadCacheControlAttr::get(
@@ -863,17 +925,17 @@ struct ConvertXeGPUToXeVMPass
});
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
- Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
- builder, loc, input);
- return arith::IndexCastUIOp::create(builder, loc, type,
- addr).getResult();
+ Value addr =
+ memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
+ return arith::IndexCastUIOp::create(builder, loc, type, addr)
+ .getResult();
}
return {};
};
@@ -888,7 +950,8 @@ struct ConvertXeGPUToXeVMPass
Value cast =
index::CastUOp::create(builder, loc, builder.getIndexType(), input)
.getResult();
- return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
}
return {};
};
@@ -903,7 +966,8 @@ struct ConvertXeGPUToXeVMPass
Value cast =
index::CastUOp::create(builder, loc, builder.getIndexType(), input)
.getResult();
- return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+ return arith::IndexCastUIOp::create(builder, loc, type, cast)
+ .getResult();
}
return {};
};
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 825a4d6368863..0f67dc290689b 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -37,15 +37,15 @@ gpu.module @test {
// CHECK-LABEL: @load_gather_memref_src_constant_offset
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -73,12 +73,12 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
// CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -99,51 +99,15 @@ gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: ve
}
// -----
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_load_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) {
- // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
- // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
- // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
- // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64
- // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
- %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex>
- -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
- // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64
- // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64
- // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
- // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
- // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) {
- // CHECK: %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
- // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16>
- // CHECK: scf.yield %[[VAR12]] : vector<8xf16>
- // CHECK: } else {
- // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
- // CHECK: scf.yield %[[CST_0]] : vector<8xf16>
- %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
- gpu.return
-}
-}
-// -----
-
gpu.module @test {
// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
// CHECK-SAME: %[[ARG0:.*]]: ui64
gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
@@ -170,17 +134,17 @@ gpu.module @test {
// CHECK-LABEL: @store_scatter_memref_src_constant_offset
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
// CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
%2 = arith.constant dense<2.9>: vector<2xf16>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
// CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -202,14 +166,15 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
%1 = arith.constant dense<1>: vector<1xi1>
// CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+ // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
%2 = arith.constant dense<2.9>: vector<1xf32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -217,8 +182,8 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset:
-> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
// CHECK: scf.if %[[VAR2]] {
- // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
- // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1>
+ // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+ // CHECK-SAME: : f32, !llvm.ptr<1>
xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
gpu.return
@@ -226,49 +191,15 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset:
}
// -----
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_store_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
- // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
- // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
- // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
- // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
- %1 = arith.constant dense<1>: vector<1xi1>
- // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
- %2 = arith.constant dense<2.9>: vector<1xf32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
- %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64
- // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
- // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
- // CHECK: scf.if %[[VAR4]] {
- // CHECK: llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
- // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1>
- xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
- : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
- gpu.return
-}
-}
-// -----
-
gpu.module @test {
// CHECK-LABEL: @prefetch_ui64_src_constant_offset
// CHECK-SAME: %[[ARG0:.*]]: ui64
gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+ // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
@@ -288,12 +219,12 @@ gpu.module @test {
// CHECK-LABEL: @prefetch_memref_src_constant_offset
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
%0 = arith.constant dense<0> : vector<1xindex>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
// CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
// CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
@@ -313,7 +244,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
// CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
// CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
@@ -328,30 +259,3 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
gpu.return
}
}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @prefetch_memref_src_prefetch_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
- // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
- // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
- // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
- // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
- // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
- %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
- -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
- // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
- // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
- // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64
- // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1>
- // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
- xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xindex>
- gpu.return
-}
-}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index a7ae4d9b7e4d2..8db0843de4cc1 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -1,45 +1,68 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s
gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_memref
// CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
- // CHECK: XXX
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+ // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64
%offset = arith.constant dense<0> : vector<1xindex>
%src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
-> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
+}
+
+// -----
+gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_ui64
// CHECK-SAME: %[[ARG0:.*]]: ui64
gpu.func @materialize_ui64(%src: ui64) kernel {
- // CHECK: XXX
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
%offset = arith.constant dense<0> : vector<1xindex>
%src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
-> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
+}
+
+// -----
+gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_ui32
// CHECK-SAME: %[[ARG0:.*]]: ui32
gpu.func @materialize_ui32(%src: ui32) kernel {
+ // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index
+ // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32
%offset = arith.constant dense<0> : vector<1xindex>
- //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
- // -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+ %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+ -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
+}
+
+// -----
+gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_single_index_vector
// CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
- // CHECK: XXX
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+ // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+ // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64
%offset = arith.constant dense<0> : vector<1xindex>
%src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
-> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
gpu.return
}
+}
+
+// -----
+gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_single_elem_vector
- // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1>
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
- // CHECK: XXX
+ // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+ // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
%mask = arith.constant dense<1>: vector<1xi1>
%offset = arith.constant dense<0> : vector<1xindex>
%0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
>From d372592306858cc17a2cac6397e4755f5cdc826a Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 21:11:26 +0000
Subject: [PATCH 07/17] Temp save.
---
.../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h | 8 +-
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 119 ++++++++----------
.../XeGPUToXeVM/create_nd_tdesc.mlir | 12 +-
3 files changed, 62 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
index fb23d24b0161b..ddaaae82e03be 100644
--- a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
-#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
#include <memory>
@@ -20,8 +20,8 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
void populateXeGPUToXeVMConversionPatterns(
- mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+ const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
} // namespace mlir
-#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6cfa8ac1f8fce..19324f748bded 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -57,7 +57,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
llvm_unreachable("Unknown XeGPU memory space.");
}
-VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+ Type toElemType) {
auto elemType = currentVecType.getElementType();
auto currentBitWidth = elemType.getIntOrFloatBitWidth();
auto newBitWidth = toElemType.getIntOrFloatBitWidth();
@@ -66,13 +67,11 @@ VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
return VectorType::get(size, toElemType);
}
-xevm::LoadCacheControl
+static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
- auto L1hintVal =
- L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
- auto L3hintVal =
- L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
switch (L1hintVal) {
case xegpu::CachePolicy::CACHED:
if (L3hintVal == xegpu::CachePolicy::CACHED)
@@ -102,13 +101,11 @@ translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
-xevm::StoreCacheControl
+static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
- auto L1hintVal =
- L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
- auto L3hintVal =
- L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
switch (L1hintVal) {
case xegpu::CachePolicy::UNCACHED:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
@@ -152,10 +149,14 @@ class CreateNdDescToXeVMPattern
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto source = op.getSource();
+ // Op is lowered to a code sequence that populates payload.
+ // payload is a 8xi32 vector.
Type payloadElemTy = rewriter.getI32Type();
Type i64Ty = rewriter.getI64Type();
VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ // 4xi64 view is used for inserting the base pointer.
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ // Initialize payload to zero.
Value payload = arith::ConstantOp::create(
rewriter, loc,
DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
@@ -166,73 +167,56 @@ class CreateNdDescToXeVMPattern
Value offsetW;
Value offsetH;
- bool sourceIsMemref = false;
+ // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+ SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+ // Descriptor shape is expected to be 2D.
+ int64_t rank = mixedSizes.size();
+ if (rank != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
auto sourceTy = source.getType();
- int64_t rank;
- if (isa<MemRefType>(sourceTy)) {
- sourceIsMemref = true;
+ auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+ // If source is a memref, we need to extract the aligned pointer as index.
+ // pointer type is passed as i32 or i64 by type converter.
+ if (sourceMemrefTy) {
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
- auto sourceMemrefTy = cast<MemRefType>(sourceTy);
if (!sourceMemrefTy.hasStaticShape()) {
op.emitError() << "Expected static memref shape.";
return failure();
}
- rank = sourceMemrefTy.getRank();
- if (rank != 2) {
- op.emitError() << "Expected a 2D memref.";
- return failure();
- }
- } else if (sourceTy == rewriter.getIntegerType(64, false)) {
- rank = op.getMixedSizes().size();
} else {
- op.emitError() << "Expected source to be a 2D memref or ui64.";
- return failure();
+ baseAddr = adaptor.getSource();
}
- auto createOffset = [&](unsigned idx) -> Value {
- Value val;
- OpFoldResult ofr = op.getMixedOffsets()[idx];
- if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
- val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
- val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
- } else {
- int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
- val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
- }
+ // utility for creating offset values from op fold result.
+ auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+ unsigned idx) -> Value {
+ Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+ val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- auto offsets = op.getMixedOffsets();
- if (offsets.size() == 2) {
- offsetW = createOffset(rank - 1);
- offsetH = createOffset(rank - 2);
- } else {
+ // Offsets can be either 2D or not provided (0 is used).
+ if (mixedOffsets.size() == 2) {
+ offsetW = createOffset(mixedOffsets, rank - 1);
+ offsetH = createOffset(mixedOffsets, rank - 2);
+ } else if (mixedOffsets.size() == 0) {
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
}
- auto createShape = [&](unsigned idx) -> Value {
- Value val;
- OpFoldResult ofr = op.getMixedSizes()[idx];
- if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
- val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
- val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
- } else {
- int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
- val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
- }
- return val;
- };
- if (sourceIsMemref) {
- auto sourceMemrefTy = cast<MemRefType>(sourceTy);
- baseShapeW = arith::ConstantIntOp::create(
- rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
- baseShapeH = arith::ConstantIntOp::create(
- rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+ // Get shape values from op fold results.
+ baseShapeW = createOffset(mixedSizes, rank - 1);
+ baseShapeH = createOffset(mixedSizes, rank - 2);
+ if (sourceMemrefTy) {
+ // cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else {
- baseShapeW = createShape(rank - 1);
- baseShapeH = createShape(rank - 2);
- baseAddr = adaptor.getSource();
+ } else if (baseAddr.getType() != i64Ty) {
+ // pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
+ // Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
payLoadAsI64 =
@@ -429,9 +413,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset,
- int64_t elemByteSize) -> Value {
+static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
+ Value baseAddr, Value offset,
+ int64_t elemByteSize) -> Value {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
@@ -701,6 +685,7 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
return success();
}
};
+
class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -1007,7 +992,7 @@ struct ConvertXeGPUToXeVMPass
target.addIllegalDialect<xegpu::XeGPUDialect>();
RewritePatternSet patterns(&getContext());
- populateXeGPUToXeVMConversionPatterns(patterns, typeConverter);
+ populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
patterns, target);
if (failed(applyPartialConversion(getOperation(), target,
@@ -1021,7 +1006,7 @@ struct ConvertXeGPUToXeVMPass
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
- RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) {
+ const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 7f5e3527a1594..ba7ece8ccbebe 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -11,10 +11,8 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
- // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
- // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32
- // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64
- // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32
+ // CHECK: %[[VAR3:.*]] = arith.index_cast %[[ARG3]] : index to i32
+ // CHECK: %[[VAR5:.*]] = arith.index_cast %[[ARG2]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
@@ -32,8 +30,10 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
// CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
- // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
- // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
+ // CHECK: %[[C16_I32:.*]] = arith.trunci %c16_i64 : i64 to i32
+ // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
+ // CHECK: %[[C8_I32:.*]] = arith.trunci %c8_i64 : i64 to i32
// CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>
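For reference, a minimal sketch of driving the reordered populateXeGPUToXeVMConversionPatterns entry point outside the bundled pass. This is not part of the patch; the include paths and the legality setup are assumptions that loosely mirror the ConvertXeGPUToXeVMPass body shown elsewhere in this diff, and a real caller would likely need additional legal dialects (vector, arith, memref, gpu, scf).

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical standalone driver; not part of this patch.
static LogicalResult lowerXeGPUToXeVM(Operation *root) {
  MLIRContext *ctx = root->getContext();
  LLVMTypeConverter typeConverter(ctx);
  ConversionTarget target(*ctx);
  target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect>();
  target.addIllegalDialect<xegpu::XeGPUDialect>();
  RewritePatternSet patterns(ctx);
  // Argument order introduced by this patch: type converter first, patterns second.
  populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
  scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
                                                       target);
  return applyPartialConversion(root, target, std::move(patterns));
}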
>From d88d676129de7c10f4eebae4e1bc5d777ea91d21 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 22:11:43 +0000
Subject: [PATCH 08/17] Update update_nd_tdesc.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 39 +++++++++----------
1 file changed, 19 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 19324f748bded..12af0af70177b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -248,27 +248,26 @@ class UpdateNdOffsetToXeVMPattern
xegpu::UpdateNdOffsetOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- auto offsets = op.getOffsets();
+ auto mixedOffsets = op.getMixedOffsets();
+ if (mixedOffsets.size() != 2)
+ return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto tdesc = adaptor.getTensorDesc();
- for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
- auto offset = offsets[offsetDim];
- if (auto cst =
- dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
- if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
- attr && !attr.getInt())
- continue;
- const int offsetPos =
- static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
- : NdDescI32Layout::TensorOffsetH);
- auto oldOffset =
- vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
- offset = arith::IndexCastUIOp::create(rewriter, loc,
- rewriter.getI32Type(), offset);
- auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- tdesc =
- vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
- }
- rewriter.replaceOp(op, tdesc);
+ // utility for updating payload offset values from op fold result.
+ auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offset);
+ Value oldOffset =
+ vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+ Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+ return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+ payloadPos);
+ };
+ auto val =
+ updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ rewriter.replaceOp(op, val);
return success();
}
};
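Context for the rewrite above (not part of the patch): getValueOrCreateConstantIntOp and getValueOrCreateCastToIndexLike, declared in mlir/Dialect/Arith/Utils/Utils.h, take over the attribute-vs-value handling that the deleted loop did by hand. A condensed sketch of what one normalized offset amounts to, assuming the OpFoldResult holds either an IntegerAttr or an index/integer Value:

// Sketch only; the patch calls these helpers inline rather than via a wrapper.
static Value normalizeOfrToI32(OpBuilder &b, Location loc, OpFoldResult ofr) {
  // Attribute case is materialized as an arith.constant; Value case passes through.
  Value v = getValueOrCreateConstantIntOp(b, loc, ofr);
  // Then cast to i32 (index_cast / trunc / ext, whichever the input type needs).
  return getValueOrCreateCastToIndexLike(b, loc, b.getI32Type(), v);
}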
>From 236343e189f8ebea578ce872a0f2206a31ab1536 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 22:29:06 +0000
Subject: [PATCH 09/17] Temp save.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 160 ++++++++++--------
.../Conversion/XeGPUToXeVM/loadstore_nd.mlir | 12 +-
.../XeGPUToXeVM/materializecast.mlir | 2 +-
.../Conversion/XeGPUToXeVM/prefetch_nd.mlir | 6 +-
4 files changed, 99 insertions(+), 81 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 12af0af70177b..963ab29695b1f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -39,12 +39,13 @@ using namespace mlir;
namespace {
-enum class NdDescI32Layout : uint32_t {
- BasePtr = 0,
- BaseShapeW = 2,
- BaseShapeH = 3,
- TensorOffsetW = 4,
- TensorOffsetH = 5
+// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ TensorOffsetW = 4, // Tensor offset W (i32)
+ TensorOffsetH = 5 // Tensor offset H (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -57,6 +58,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
llvm_unreachable("Unknown XeGPU memory space.");
}
+// Get same bitwidth flat vector type of new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
Type toElemType) {
auto elemType = currentVecType.getElementType();
@@ -221,20 +223,20 @@ class CreateNdDescToXeVMPattern
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
payLoadAsI64 =
vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
- static_cast<int>(NdDescI32Layout::BasePtr));
+ static_cast<int>(NdTdescOffset::BasePtr));
payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
payload =
vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
- static_cast<int>(NdDescI32Layout::BaseShapeW));
+ static_cast<int>(NdTdescOffset::BaseShapeW));
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
- static_cast<int>(NdDescI32Layout::BaseShapeH));
+ static_cast<int>(NdTdescOffset::BaseShapeH));
payload = vector::InsertOp::create(
rewriter, loc, offsetW, payload,
- static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ static_cast<int>(NdTdescOffset::TensorOffsetW));
payload = vector::InsertOp::create(
rewriter, loc, offsetH, payload,
- static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ static_cast<int>(NdTdescOffset::TensorOffsetH));
rewriter.replaceOp(op, payload);
return success();
}
@@ -249,6 +251,7 @@ class UpdateNdOffsetToXeVMPattern
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto mixedOffsets = op.getMixedOffsets();
+ // Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto tdesc = adaptor.getTensorDesc();
@@ -264,9 +267,9 @@ class UpdateNdOffsetToXeVMPattern
return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
payloadPos);
};
- auto val =
- updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
- val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ // Update offsets in the payload.
+ auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
rewriter.replaceOp(op, val);
return success();
}
@@ -293,62 +296,46 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
- Value basePtr =
- vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
- static_cast<int>(NdDescI32Layout::BasePtr));
+ Value basePtr = vector::ExtractOp::create(
+ rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
Value baseShapeW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
- // Offsets can come from three sources:
- // 1. Constant offsets, which are provided by the op.
- // 2. Offsets as operands, which are provided by the op.
- // 3. Offsets extracted from the tensor descriptor.
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ // Offsets provided in two ways:
+ // 1. Offsets are extracted from the tensor descriptor.
+ // 2. (Mixed) offsets which are provided by the op.
Value offsetW;
Value offsetH;
- auto cOffsets = op.getConstOffsets();
- auto offsets = op.getOffsets();
- if (cOffsets) {
- offsetW = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
- offsetH = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
- } else if (offsets.size() != 0) {
- // offsets are provided as operands
- if (offsets[0].getType() != rewriter.getI32Type()) {
- if (offsets[0].getType() != rewriter.getIndexType()) {
- return rewriter.notifyMatchFailure(
- op, "Expected offsets to be of type i32 or index.");
- }
- offsetW = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), offsets[0]);
- } else {
- offsetW = offsets[0];
- }
- if (offsets[1].getType() != rewriter.getI32Type()) {
- if (offsets[1].getType() != rewriter.getIndexType()) {
- return rewriter.notifyMatchFailure(
- op, "Expected offsets to be of type i32 or index.");
- }
- offsetH = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), offsets[1]);
- } else {
- offsetH = offsets[1];
- }
+ auto mixedOffsets = op.getMixedOffsets();
+ int64_t opOffsetsSize = mixedOffsets.size();
+ if (opOffsetsSize != 0 && opOffsetsSize != 2) {
+ return rewriter.notifyMatchFailure(op,
+ "Expected 2D offsets or no offsets.");
+ }
+ if (opOffsetsSize) {
+ // If mixed offsets are provided by the op convert them to i32.
+ offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
} else {
// If offsets are not available, we need to extract them from the tensor
// descriptor.
offsetW = vector::ExtractOp::create(
- rewriter, loc, tdesc,
- static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
offsetH = vector::ExtractOp::create(
- rewriter, loc, tdesc,
- static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
}
+ // Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ // Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // Compute element byte size and surface width in bytes.
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
Value elemByteSize = arith::ConstantIntOp::create(
@@ -356,23 +343,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Get tile sizes and vblocks from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(1);
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+ VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+ if (!srcVecTy) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ }
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- VectorType srcFlatVecTy =
- VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
- Value srcFlatVec = op.getValue();
- srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
- rewriter.getIntegerType(elemBitSize));
- srcFlatVec =
- vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+ Value src = adaptor.getValue();
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+ offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
@@ -412,15 +403,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset,
- int64_t elemByteSize) -> Value {
+static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
+ Value baseAddr, Value offset, int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
-};
+}
class CreateDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateDescOp> {
@@ -908,6 +898,10 @@ struct ConvertXeGPUToXeVMPass
return IntegerType::get(&getContext(), 64);
});
+ // LLVM type converter puts unrealized casts for the following cases:
+ // add materialization casts to handle them.
+
+ // Materialization to convert memref to i64
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
@@ -924,6 +918,7 @@ struct ConvertXeGPUToXeVMPass
return {};
};
+ // Materialization to convert ui64 to i64
auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
@@ -940,6 +935,7 @@ struct ConvertXeGPUToXeVMPass
return {};
};
+ // Materialization to convert ui32 to i32
auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
@@ -956,9 +952,13 @@ struct ConvertXeGPUToXeVMPass
return {};
};
- auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
- ValueRange inputs,
- Location loc) -> Value {
+ // Materialization to convert
+ // - single element 1D vector to scalar
+ // - bitcast vector of same rank
+ // - shape vector of different rank but same element type
+ auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+ ValueRange inputs,
+ Location loc) -> Value {
if (inputs.size() != 1)
return {};
auto input = inputs.front();
@@ -971,6 +971,18 @@ struct ConvertXeGPUToXeVMPass
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
.getResult();
return cast;
+ } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+ // If the target type is a vector of same rank,
+ // bitcast to the target type.
+ if (targetVecTy.getRank() == vecTy.getRank())
+ return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+ // If the target type is a vector of different rank but same element
+ // type, reshape to the target type.
+ return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+ .getResult();
+ }
}
}
return {};
@@ -978,11 +990,11 @@ struct ConvertXeGPUToXeVMPass
typeConverter.addSourceMaterialization(memrefMaterializationCast);
typeConverter.addSourceMaterialization(ui64MaterializationCast);
typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
- typeConverter.addTargetMaterialization(vector1DMaterializationCast);
+ typeConverter.addTargetMaterialization(vectorMaterializationCast);
ConversionTarget target(getContext());
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
vector::VectorDialect, arith::ArithDialect,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index c692da632d458..4c6bbf25b4728 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -20,8 +20,10 @@ gpu.module @load_store_check {
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
//CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32
- //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
+ //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
//CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
//CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
@@ -54,8 +56,10 @@ gpu.module @load_store_check {
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
//CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32
- //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
+ //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
//CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index 8db0843de4cc1..2445c4b341657 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -66,7 +66,7 @@ gpu.module @materializecast {
%mask = arith.constant dense<1>: vector<1xi1>
%offset = arith.constant dense<0> : vector<1xindex>
%0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
- : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
+ : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
gpu.return
}
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 8513b4f9857fb..873478aed57e3 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -20,8 +20,10 @@ gpu.module @fence_check {
//CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
//CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32
- //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32
+ //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
+ //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
+ //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
//CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
//CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
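To read the materialization plumbing of this patch in one place: the hunks above register unrealized-cast handlers on the LLVM type converter, and the memref case reduces to the sketch below. This assumes the surrounding ConvertXeGPUToXeVMPass body (typeConverter in scope) and adds no functionality beyond what the hunks already show.

// Memref operands are materialized as their aligned base address, cast to the
// requested integer type (i64 for the patterns in this pass).
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                    ValueRange inputs, Location loc) -> Value {
  if (inputs.size() != 1)
    return {};
  Value input = inputs.front();
  if (!isa<MemRefType>(input.getType()))
    return {};
  Value addr =
      memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
  return arith::IndexCastUIOp::create(builder, loc, type, addr).getResult();
};
typeConverter.addSourceMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);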
>From a8cd5e08cfc536fb27f41f23820b21d0883afd8f Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 26 Aug 2025 00:49:47 +0000
Subject: [PATCH 10/17] Address comments.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 23 +++++++++----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 963ab29695b1f..6cd50a38a21a4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -152,7 +152,7 @@ class CreateNdDescToXeVMPattern
auto loc = op.getLoc();
auto source = op.getSource();
// Op is lowered to a code sequence that populates payload.
- // payload is a 8xi32 vector.
+ // Payload is a 8xi32 vector.
Type payloadElemTy = rewriter.getI32Type();
Type i64Ty = rewriter.getI64Type();
VectorType payloadTy = VectorType::get(8, payloadElemTy);
@@ -179,7 +179,7 @@ class CreateNdDescToXeVMPattern
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
- // pointer type is passed as i32 or i64 by type converter.
+ // Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
@@ -190,7 +190,7 @@ class CreateNdDescToXeVMPattern
} else {
baseAddr = adaptor.getSource();
}
- // utility for creating offset values from op fold result.
+ // Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
unsigned idx) -> Value {
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
@@ -212,10 +212,10 @@ class CreateNdDescToXeVMPattern
baseShapeW = createOffset(mixedSizes, rank - 1);
baseShapeH = createOffset(mixedSizes, rank - 2);
if (sourceMemrefTy) {
- // cast index to i64.
+ // Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
} else if (baseAddr.getType() != i64Ty) {
- // pointer type may be i32. Cast to i64 if needed.
+ // Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
// Populate payload.
@@ -255,7 +255,7 @@ class UpdateNdOffsetToXeVMPattern
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto tdesc = adaptor.getTensorDesc();
- // utility for updating payload offset values from op fold result.
+ // Utility for updating payload offset values from op fold result.
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
Value offset =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
@@ -425,7 +425,7 @@ class CreateDescToXeVMPattern
op, "Expected element type bit width to be multiple of 8.");
}
auto loc = op.getLoc();
- // offsets are provided as scalar i64 by type converter.
+ // Offsets are provided as scalar i64 by type converter.
auto offsets = adaptor.getOffsets();
// Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
// But type converter will convert them to integer types.
@@ -453,8 +453,8 @@ class UpdateOffsetToXeVMPattern
op, "Expected element type bit width to be multiple of 8.");
}
auto loc = op.getLoc();
- // scatter descriptor is provided as scalar i64 by type converter.
- // offsets are provided as scalar i64 by type converter.
+ // Scatter descriptor is provided as scalar i64 by type converter.
+ // Offsets are provided as scalar i64 by type converter.
Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
adaptor.getOffsets(), eBw / 8);
rewriter.replaceOp(op, newOffset);
@@ -583,7 +583,7 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
rewriter.replaceOp(op, ifOp.getResult(0));
} else {
- // if mask is true, perform the store.
+ // If mask is true, perform the store.
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
auto body = ifOp.getBody();
rewriter.setInsertionPointToStart(body);
@@ -758,7 +758,7 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
if (cvecty != cNty)
c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
- // below are uArch dependent values, should move away from hardcoding
+ // Below are uArch dependent values, should move away from hardcoding
constexpr int32_t systolicDepth{8};
constexpr int32_t executionSize{16};
Value dpasRes = xevm::MMAOp::create(
@@ -818,7 +818,6 @@ matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
default:
return std::nullopt;
}
- llvm_unreachable("Invalid AtomicRMWKind");
}
class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
>From 953b8508152bab32cd905cd582c29c5340adcf2f Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 26 Aug 2025 01:00:46 +0000
Subject: [PATCH 11/17] Remove unneeded llvm_unreachable.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6cd50a38a21a4..172b09bacdc03 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -55,7 +55,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
- llvm_unreachable("Unknown XeGPU memory space.");
}
// Get same bitwidth flat vector type of new element type.
@@ -689,7 +688,6 @@ class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
case xegpu::FenceScope::GPU:
memScope = xevm::MemScope::DEVICE;
break;
- llvm_unreachable("Unknown XeGPU fence scope.");
}
xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
switch (op.getMemoryKind()) {
@@ -699,7 +697,6 @@ class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
case xegpu::MemorySpace::SLM:
addrSpace = xevm::AddrSpace::SHARED;
break;
- llvm_unreachable("Unknown XeGPU fence scope.");
}
xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
rewriter.eraseOp(op);
>From 10a6aff857c9a033f290a1ed1fcf73f466a24cb8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 26 Aug 2025 01:21:25 +0000
Subject: [PATCH 12/17] Remove redundant braces.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 54 +++++++------------
1 file changed, 20 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 172b09bacdc03..ae42489196303 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -210,13 +210,13 @@ class CreateNdDescToXeVMPattern
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, rank - 1);
baseShapeH = createOffset(mixedSizes, rank - 2);
- if (sourceMemrefTy) {
+ if (sourceMemrefTy)
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else if (baseAddr.getType() != i64Ty) {
+ else if (baseAddr.getType() != i64Ty)
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
- }
+
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -288,9 +288,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2) {
+ if (tdescTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
- }
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
@@ -308,10 +307,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value offsetH;
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 0 && opOffsetsSize != 2) {
+ if (opOffsetsSize != 0 && opOffsetsSize != 2)
return rewriter.notifyMatchFailure(op,
"Expected 2D offsets or no offsets.");
- }
if (opOffsetsSize) {
// If mixed offsets are provided by the op convert them to i32.
offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
@@ -348,10 +346,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
- if (!srcVecTy) {
+ if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
- }
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
Value src = adaptor.getValue();
@@ -419,10 +416,9 @@ class CreateDescToXeVMPattern
ConversionPatternRewriter &rewriter) const override {
auto eTy = op.getTensorDescType().getElementType();
auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0) {
+ if (eBw % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
- }
auto loc = op.getLoc();
// Offsets are provided as scalar i64 by type converter.
auto offsets = adaptor.getOffsets();
@@ -447,10 +443,9 @@ class UpdateOffsetToXeVMPattern
ConversionPatternRewriter &rewriter) const override {
auto eTy = op.getTensorDescType().getElementType();
auto eBw = eTy.getIntOrFloatBitWidth();
- if (eBw % 8 != 0) {
+ if (eBw % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
- }
auto loc = op.getLoc();
// Scatter descriptor is provided as scalar i64 by type converter.
// Offsets are provided as scalar i64 by type converter.
@@ -475,30 +470,27 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrI64;
// Load result or store value type can be vector or scalar.
Type valOrResTy;
- if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
valOrResTy = op.getResult().getType();
- } else {
+ else
valOrResTy = adaptor.getValue().getType();
- }
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
bool hasScalarVal = !valOrResVecTy;
int64_t elemBitWidth =
hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
: valOrResVecTy.getElementType().getIntOrFloatBitWidth();
// Element type must be multiple of 8 bits.
- if (elemBitWidth % 8 != 0) {
+ if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
- }
int64_t elemByteSize = elemBitWidth / 8;
// Default memory space is global.
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
// If tensor descriptor is available, we use its memory space.
- if (tdescTy) {
+ if (tdescTy)
ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- }
// Base pointer can come from source (load) or dest (store).
// If they are memrefs, we use their memory space.
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
@@ -524,18 +516,17 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
Value offsets = adaptor.getOffsets();
Value mask = adaptor.getMask();
if (offsets) {
- if (dyn_cast<VectorType>(offsets.getType())) {
+ if (dyn_cast<VectorType>(offsets.getType()))
// Offset needs to be scalar. Single element vector is converted to scalar
// by type converter.
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
- } else {
+ else
// If offsets are provided, we add them to the base pointer.
// Offsets are in number of elements, we need to multiply by
// element byte size.
basePtrI64 =
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
- }
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -543,13 +534,12 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
Value maskForLane;
VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
- if (maskVecTy) {
+ if (maskVecTy)
// Mask needs to be scalar. Single element vector is converted to scalar by
// type converter.
return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
- } else {
+ else
maskForLane = mask;
- }
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
maskForLane, true, true);
@@ -609,10 +599,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
auto tdescTy = op.getTensorDescType();
Value basePtrI64 = adaptor.getSource();
// Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
- if (basePtrI64.getType() != rewriter.getI64Type()) {
+ if (basePtrI64.getType() != rewriter.getI64Type())
basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
basePtrI64);
- }
Value offsets = adaptor.getOffsets();
if (offsets) {
VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
@@ -637,10 +626,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
elemByteSize = *op.getOffsetAlignByte();
}
if (elemBitWidth != 0) {
- if (elemBitWidth % 8 != 0) {
+ if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
- }
elemByteSize = elemBitWidth / 8;
}
basePtrI64 =
@@ -651,10 +639,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
// If tensor descriptor is available, we use its memory space.
- if (tdescTy) {
+ if (tdescTy)
ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- }
// If source is a memref, we use its memory space.
if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
auto addrSpace = memRefTy.getMemorySpaceAsInt();
@@ -883,9 +870,8 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
- if (type.isScattered()) {
+ if (type.isScattered())
return IntegerType::get(&getContext(), 64);
- }
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
>From dea2933d2db7100cc84556f6d9321f3a82bba331 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 26 Aug 2025 14:52:08 +0000
Subject: [PATCH 13/17] Add element bitwidth restriction.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index ae42489196303..712ed1ee88988 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -290,6 +290,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tdescTy = op.getTensorDescType();
if (tdescTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto elemType = tdescTy.getElementType();
+ auto elemBitSize = elemType.getIntOrFloatBitWidth();
+ if (elemBitSize % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
@@ -333,8 +338,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute element byte size and surface width in bytes.
- auto elemType = tdescTy.getElementType();
- auto elemBitSize = elemType.getIntOrFloatBitWidth();
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
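The guard added in this patch only asks that element types be byte addressable; a standalone sketch of the predicate (illustrative helper, not code from the pass):

#include <cstdint>

// Byte-addressable element widths (8, 16, 32, 64, ...) pass; sub-byte widths
// such as 1 or 4 bits make the pattern bail out before any byte-size math.
static bool isByteAddressable(uint32_t elemBitWidth) {
  return elemBitWidth % 8 == 0;
}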
>From b01086ed5279acc7018897708808493e8a4b3d98 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 27 Aug 2025 17:19:37 +0000
Subject: [PATCH 14/17] Address comments.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 712ed1ee88988..8d2ad4e999d38 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -151,10 +151,11 @@ class CreateNdDescToXeVMPattern
auto loc = op.getLoc();
auto source = op.getSource();
// Op is lowered to a code sequence that populates payload.
- // Payload is a 8xi32 vector.
+ // Payload is an 8xi32 vector. Offsets to individual fields are defined in
+ // the NdTdescOffset enum.
Type payloadElemTy = rewriter.getI32Type();
- Type i64Ty = rewriter.getI64Type();
VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ Type i64Ty = rewriter.getI64Type();
// 4xi64 view is used for inserting the base pointer.
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
// Initialize payload to zero.
@@ -180,12 +181,12 @@ class CreateNdDescToXeVMPattern
// If source is a memref, we need to extract the aligned pointer as index.
// Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
- baseAddr =
- memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
if (!sourceMemrefTy.hasStaticShape()) {
op.emitError() << "Expected static memref shape.";
return failure();
}
+ baseAddr =
+ memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
} else {
baseAddr = adaptor.getSource();
}
@@ -198,8 +199,8 @@ class CreateNdDescToXeVMPattern
};
// Offsets can be either 2D or not provided (0 is used).
if (mixedOffsets.size() == 2) {
- offsetW = createOffset(mixedOffsets, rank - 1);
- offsetH = createOffset(mixedOffsets, rank - 2);
+ offsetW = createOffset(mixedOffsets, 1);
+ offsetH = createOffset(mixedOffsets, 0);
} else if (mixedOffsets.size() == 0) {
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
@@ -208,8 +209,8 @@ class CreateNdDescToXeVMPattern
"Expected 2D offsets or no offsets.");
}
// Get shape values from op fold results.
- baseShapeW = createOffset(mixedSizes, rank - 1);
- baseShapeH = createOffset(mixedSizes, rank - 2);
+ baseShapeW = createOffset(mixedSizes, 1);
+ baseShapeH = createOffset(mixedSizes, 0);
if (sourceMemrefTy)
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
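For reference, the fixed indices above (W from position 1, H from position 0 of the mixed offsets/sizes) feed the 8xi32 descriptor payload whose lane assignment shows up in the create_nd_tdesc tests further down. A sketch of that layout as implied by the vector.insert positions; only BasePtr = 0 is spelled out in this series, the other names are placeholders rather than the exact NdTdescOffset enumerators:

#include <cstdint>

// Lane assignment of the 8xi32 nd tensor descriptor payload, inferred from
// the vector.insert indices in the tests (placeholder names).
enum class NdTdescLaneSketch : uint32_t {
  BasePtr = 0,       // base address, one i64 spanning i32 lanes 0 and 1
  BaseShapeW = 2,    // innermost (width) dimension of the base shape
  BaseShapeH = 3,    // outer (height) dimension of the base shape
  TensorOffsetW = 4, // innermost (width) offset
  TensorOffsetH = 5, // outer (height) offset
};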
>From a2887f24e22dc1293f4c8851ae1d78748a357d17 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 27 Aug 2025 18:25:59 +0000
Subject: [PATCH 15/17] Address comments.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 20 +++---
.../XeGPUToXeVM/create_nd_tdesc.mlir | 64 ++++++++++++-------
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 3 +-
3 files changed, 53 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8d2ad4e999d38..187d8d805a06e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1,4 +1,4 @@
-//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -182,8 +182,7 @@ class CreateNdDescToXeVMPattern
// Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
if (!sourceMemrefTy.hasStaticShape()) {
- op.emitError() << "Expected static memref shape.";
- return failure();
+ return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
@@ -211,13 +210,13 @@ class CreateNdDescToXeVMPattern
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
- if (sourceMemrefTy)
+ if (sourceMemrefTy) {
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- else if (baseAddr.getType() != i64Ty)
+ } else if (baseAddr.getType() != i64Ty) {
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-
+ }
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -520,17 +519,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
Value offsets = adaptor.getOffsets();
Value mask = adaptor.getMask();
if (offsets) {
- if (dyn_cast<VectorType>(offsets.getType()))
+ if (dyn_cast<VectorType>(offsets.getType())) {
// Offset needs to be scalar. Single element vector is converted to scalar
// by type converter.
return rewriter.notifyMatchFailure(op,
"Expected offsets to be a scalar.");
- else
+ } else {
// If offsets are provided, we add them to the base pointer.
// Offsets are in number of elements, we need to multiply by
// element byte size.
basePtrI64 =
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ }
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -538,11 +538,11 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
Value maskForLane;
VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
- if (maskVecTy)
+ if (maskVecTy) {
// Mask needs to be scalar. Single element vector is converted to scalar by
// type converter.
return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
- else
+ } else
maskForLane = mask;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ba7ece8ccbebe..4ff95b40fe68c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,24 +2,24 @@
gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64
- // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
+ // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
- %stride1: index, %stride2: index) kernel {
+ %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
- // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+ // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
- // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
- // CHECK: %[[VAR3:.*]] = arith.index_cast %[[ARG3]] : index to i32
- // CHECK: %[[VAR5:.*]] = arith.index_cast %[[ARG2]] : index to i32
+ // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
+ // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
+ // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
+ // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR9:.*]] = vector.insert %[[VAR3]], %[[VAR8]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR10:.*]] = vector.insert %[[VAR5]], %[[VAR9]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR11:.*]] = vector.insert %[[C0_I32]], %[[VAR10]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR12:.*]] = vector.insert %[[C0_I32_0]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
@@ -28,21 +28,39 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
- // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
- // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
+ // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
- // CHECK: %[[C16_I32:.*]] = arith.trunci %c16_i64 : i64 to i32
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
// CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
- // CHECK: %[[C8_I32:.*]] = arith.trunci %c8_i64 : i64 to i32
- // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+ // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR17:.*]] = vector.insert %[[C16_I32]], %[[VAR16]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR18:.*]] = vector.insert %[[C8_I32]], %[[VAR17]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR19:.*]] = vector.insert %[[C0_I32_2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR20:.*]] = vector.insert %[[C0_I32_3]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
+ // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
+ // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
+ // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+ // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
+ // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index 15940fc4aca26..e6f22f0a9acbb 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -5,7 +5,8 @@
#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
gpu.module @load_store_check {
- //CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>) -> vector<8xf32>
+ // CHECK-LABEL: func.func @dpas(
+ // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
// Loads are checked in a separate test.
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
>From a6833962ae48c49f2583165985d7db1e2761f852 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 27 Aug 2025 18:32:33 +0000
Subject: [PATCH 16/17] Add test case description.
---
mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index 2445c4b341657..b28a8c2ccf843 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s
+// This file contains tests for materialization patterns added to handle custom type conversions
+// added on top of LLVM type converter.
+
gpu.module @materializecast {
// CHECK-LABEL: gpu.func @materialize_memref
// CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
>From 8f51ef433f5f0553a130df82d909fd63a0712f62 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 27 Aug 2025 21:31:22 +0000
Subject: [PATCH 17/17] Address comments.
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 187d8d805a06e..906b943a98756 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -30,6 +30,8 @@
#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
+
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -39,6 +41,10 @@ using namespace mlir;
namespace {
+// TODO: Below are uArch dependent values, should move away from hardcoding
+static constexpr int32_t systolicDepth{8};
+static constexpr int32_t executionSize{16};
+
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
BasePtr = 0, // Base pointer (i64)
@@ -746,9 +752,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
if (cvecty != cNty)
c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
- // Below are uArch dependent values, should move away from hardcoding
- constexpr int32_t systolicDepth{8};
- constexpr int32_t executionSize{16};
Value dpasRes = xevm::MMAOp::create(
rewriter, loc, cNty, aVec, bVec, c,
xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
@@ -867,10 +870,9 @@ struct ConvertXeGPUToXeVMPass
if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
- unsigned sum = 1;
- for (unsigned i = 0; i < rank; i++) {
- sum *= type.getShape()[i];
- }
+ int64_t sum =
+ std::accumulate(type.getShape().begin(), type.getShape().end(),
+ int64_t{1}, std::multiplies<int64_t>());
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {