[Mlir-commits] [mlir] [mlir][gpu]Add GPUToXeVM lowering pipeline pass. (PR #161216)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Fri Oct 17 11:22:06 PDT 2025
https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/161216
>From 238cc7a27dd6cc6fc19b3ded8b10e8e99b257a77 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Mon, 29 Sep 2025 15:21:56 +0000
Subject: [PATCH 01/16] Add GPUToXeVM lowering pipeline pass.
It's the default GPU to XeVM lowering pipeline. It starts by lowering GPU
code to the specified compilation target (default is fatbin),
then lowers the host code.
If XeGPU ops are used, it expects the MLIR code to have
XeGPU ops already embedded in gpu code.
---
.../mlir/Dialect/GPU/Pipelines/Passes.h | 50 ++++++-
mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 5 +
.../GPU/Pipelines/GPUToXeVMPipeline.cpp | 138 ++++++++++++++++++
mlir/lib/RegisterAllPasses.cpp | 1 +
4 files changed, 193 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 035235fc7174a..c634236139d6f 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU NVVM pipeline entry points --------------------------===//
+//===- Passes.h - GPU NVVM/XeVM pipeline entry points ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -60,6 +60,47 @@ struct GPUToNVVMPipelineOptions
llvm::cl::init(false)};
};
+// Options for the gpu to xevm pipeline.
+struct GPUToXeVMPipelineOptions
+ : public PassPipelineOptions<GPUToXeVMPipelineOptions> {
+ // General lowering controls.
+ PassOptions::Option<int64_t> indexBitWidth{
+ *this, "index-bitwidth",
+ llvm::cl::desc("Bitwidth of the index type (host & device)"),
+ llvm::cl::init(64)};
+ PassOptions::Option<bool> kernelBarePtrCallConv{
+ *this, "kernel-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for device kernels"),
+ llvm::cl::init(false)};
+ PassOptions::Option<bool> hostBarePtrCallConv{
+ *this, "host-bare-ptr-calling-convention",
+ llvm::cl::desc("Use bare pointer calling convention for host launches"),
+ llvm::cl::init(false)};
+ PassOptions::Option<std::string> binaryFormat{
+ *this, "binary-format",
+ llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"),
+ llvm::cl::init("fatbin")};
+ // Options mirroring xevm-attach-target (GpuXeVMAttachTarget).
+ PassOptions::Option<std::string> xevmModuleMatcher{
+ *this, "xevm-module-matcher",
+ llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"),
+ llvm::cl::init("")};
+ PassOptions::Option<std::string> zebinTriple{
+ *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"),
+ llvm::cl::init("spirv64-unknown-unknown")};
+ PassOptions::Option<std::string> zebinChip{
+ *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"),
+ llvm::cl::init("bmg")};
+ PassOptions::Option<unsigned> optLevel{
+ *this, "opt-level",
+ llvm::cl::desc("Optimization level for attached target/codegen"),
+ llvm::cl::init(2)};
+ PassOptions::Option<std::string> cmdOptions{
+ *this, "igc-cmd-options",
+ llvm::cl::desc("Additional downstream compiler command line options"),
+ llvm::cl::init("")};
+};
+
//===----------------------------------------------------------------------===//
// Building and Registering.
//===----------------------------------------------------------------------===//
@@ -70,8 +111,15 @@ struct GPUToNVVMPipelineOptions
void buildLowerToNVVMPassPipeline(OpPassManager &pm,
const GPUToNVVMPipelineOptions &options);
+/// Adds the GPU to XeVM pipeline to the given pass manager. Transforms main
+/// dialects into XeVM targets. Begins with GPU code regions, then handles host
+/// code.
+void buildLowerToXeVMPassPipeline(OpPassManager &pm,
+ const GPUToXeVMPipelineOptions &options);
+
/// Register all pipeleines for the `gpu` dialect.
void registerGPUToNVVMPipeline();
+void registerGPUToXeVMPipeline();
} // namespace gpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index 70a9c77a6d796..f231eebe6d82b 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRGPUPipelines
GPUToNVVMPipeline.cpp
+ GPUToXeVMPipeline.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRLinalgTransforms
MLIRAffineToStandard
MLIRGPUToNVVMTransforms
+ MLIRXeGPUToXeVM
MLIRIndexToLLVM
MLIRMathToLLVM
MLIRNVGPUToNVVM
@@ -19,4 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRVectorToSCF
+ MLIRXeGPUTransforms
+ MLIRXeGPUToXeVM
+ MLIRXeVMToLLVM
)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
new file mode 100644
index 0000000000000..eedd11df7f8af
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -0,0 +1,138 @@
+//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the default GPU-to-XeVM lowering pipeline
+// (gpu-lower-to-xevm-pipeline). If XeGPU ops are used, it expects the MLIR
+// code to have XeGPU ops already embedded in gpu code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Common pipeline
+//===----------------------------------------------------------------------===//
+void buildCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes
+ pm.addPass(createCSEPass());
+ {
+ GpuXeVMAttachTargetOptions xevmTargetOptions;
+ xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
+ xevmTargetOptions.triple = options.zebinTriple;
+ xevmTargetOptions.chip = options.zebinChip;
+ xevmTargetOptions.optLevel = options.optLevel;
+ xevmTargetOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// GPUModule-specific stuff.
+//===----------------------------------------------------------------------===//
+void buildGpuPassPipeline(OpPassManager &pm,
+ const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
+ ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
+ gpuToLLVMSPVOptions.use64bitIndex = options.indexBitWidth;
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeVMToLLVMPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Host Post-GPU pipeline
+//===----------------------------------------------------------------------===//
+void buildHostPostPipeline(OpPassManager &pm,
+ const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ pm.addNestedPass<func::FuncOp>(LLVM::createLLVMRequestCWrappersPass());
+ pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
+ pm.addPass(createConvertVectorToSCFPass());
+ pm.addPass(createSCFToControlFlowPass());
+ pm.addPass(memref::createExpandStridedMetadataPass());
+ pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+ {
+ GpuToLLVMConversionPassOptions gpuToLLVMOptions;
+ gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
+ gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
+ pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
+ }
+ pm.addPass(createConvertToLLVMPass());
+ pm.addPass(createLowerAffinePass());
+ // gpu-module-to-binary
+ {
+ GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
+ gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
+ gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
+ pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
+ }
+ pm.addPass(createReconcileUnrealizedCastsPass());
+}
+} // namespace
+
+void mlir::gpu::buildLowerToXeVMPassPipeline(
+ OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
+ // Common pipelines
+ buildCommonPassPipeline(pm, options);
+
+ // GPUModule-specific stuff
+ buildGpuPassPipeline(pm, options);
+
+ // Host post-GPUModule-specific stuff
+ buildHostPostPipeline(pm, options);
+}
+
+void mlir::gpu::registerGPUToXeVMPipeline() {
+ PassPipelineRegistration<GPUToXeVMPipelineOptions>(
+ "gpu-lower-to-xevm-pipeline",
+ "The default GPU to XeVM lowering pipeline. It starts by lowering GPU "
+ "code to the "
+ "specified compilation target (default is fatbin) then lowers the host "
+ "code.",
+ buildLowerToXeVMPassPipeline);
+}
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index c67b24226ae45..dd413d2de8710 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -98,4 +98,5 @@ void mlir::registerAllPasses() {
sparse_tensor::registerSparseTensorPipelines();
tosa::registerTosaToLinalgPipelines();
gpu::registerGPUToNVVMPipeline();
+ gpu::registerGPUToXeVMPipeline();
}
>From 17bade6a5e57b9d06c6f03e9b8252a322252665d Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 30 Sep 2025 00:07:45 +0000
Subject: [PATCH 02/16] Add a pass option to provide the XeGPU code level.
XeGPU allows workgroup, subgroup, and workitem level programming.
This option tells the pass pipeline which level the
XeGPU ops belong to.
---
.../mlir/Dialect/GPU/Pipelines/Passes.h | 12 ++++--
.../GPU/Pipelines/GPUToXeVMPipeline.cpp | 42 +++++++++++--------
2 files changed, 33 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index c634236139d6f..67dc6415008f3 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -63,11 +63,17 @@ struct GPUToNVVMPipelineOptions
// Options for the gpu to xevm pipeline.
struct GPUToXeVMPipelineOptions
: public PassPipelineOptions<GPUToXeVMPipelineOptions> {
+ // XeGPU op granularity selection: workgroup | subgroup | workitem
+ PassOptions::Option<std::string> xegpuOpLevel{
+ *this, "xegpu-op-level",
+ llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
+ "subgroup | workitem"),
+ llvm::cl::init("workgroup")};
// General lowering controls.
- PassOptions::Option<int64_t> indexBitWidth{
- *this, "index-bitwidth",
+ PassOptions::Option<bool> use64bitIndex{
+ *this, "use-64bit-index",
llvm::cl::desc("Bitwidth of the index type (host & device)"),
- llvm::cl::init(64)};
+ llvm::cl::init(true)};
PassOptions::Option<bool> kernelBarePtrCallConv{
*this, "kernel-bare-ptr-calling-convention",
llvm::cl::desc("Use bare pointer calling convention for device kernels"),
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index eedd11df7f8af..ae77143f6a66d 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,24 +62,30 @@ void buildCommonPassPipeline(
//===----------------------------------------------------------------------===//
void buildGpuPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
- pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
- pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
- pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+ if (options.xegpuOpLevel == "workgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ }
+ if (options.xegpuOpLevel == "subgroup") {
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+ }
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
- ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
- gpuToLLVMSPVOptions.use64bitIndex = options.indexBitWidth;
- pm.addNestedPass<gpu::GPUModuleOp>(
- createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+ {
+ ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
+ gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
+ pm.addNestedPass<gpu::GPUModuleOp>(
+ createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+ }
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeVMToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
@@ -104,6 +110,7 @@ void buildHostPostPipeline(OpPassManager &pm,
}
pm.addPass(createConvertToLLVMPass());
pm.addPass(createLowerAffinePass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
// gpu-module-to-binary
{
GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
@@ -111,7 +118,6 @@ void buildHostPostPipeline(OpPassManager &pm,
gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
}
- pm.addPass(createReconcileUnrealizedCastsPass());
}
} // namespace
>From 75d5547915cad8feea26ff8803442db8dd614287 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 30 Sep 2025 00:10:16 +0000
Subject: [PATCH 03/16] Add a workitem level gemm test for XeGPU.
---
.../Dialect/XeGPU/SIMT/simple_gemm.mlir | 123 ++++++++++++++++++
1 file changed, 123 insertions(+)
create mode 100644 mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
new file mode 100644
index 0000000000000..ddae9c1e7eb8f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
@@ -0,0 +1,123 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workitem" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %block_x = gpu.block_id x
+ %block_y = gpu.block_id y
+ %x_block_offset = arith.muli %block_x, %c8 : index
+ %y_block_offset = arith.muli %block_y, %c16 : index
+
+ %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+
+ %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
+
+ %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+ %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+ %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+ scf.yield %dpas : vector<8xf32>
+ }
+ xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+ }
+ }
+
+ func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %memref_a = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16>
+ %memref_b = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16>
+ %memref_c = gpu.alloc () : memref<256x256xf32>
+ gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32>
+ gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
+ gpu.wait // Wait for the kernel to finish.
+ gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32>
+ gpu.dealloc %memref_a : memref<256x256xf16>
+ gpu.dealloc %memref_b : memref<256x256xf16>
+ gpu.dealloc %memref_c : memref<256x256xf32>
+ return %c : memref<256x256xf32>
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1_f16 = arith.constant 1.0 : f16
+ %c2_f16 = arith.constant 2.0 : f16
+ %c256 = arith.constant 256 : index
+ %cf_0 = arith.constant 0.0 : f16
+ %cf_1 = arith.constant 1.0 : f16
+ %A = memref.alloc() : memref<256x256xf16>
+ %B = memref.alloc() : memref<256x256xf16>
+ %C = memref.alloc() : memref<256x256xf32>
+ %C_ref = memref.alloc() : memref<256x256xf32>
+ %c_gen_int = arith.constant 0 : i1
+ %cf_lower = arith.constant -0.5 : f32
+ %cf_upper = arith.constant 0.5 : f32
+
+ // Initialize matrix A ; A[i, j] = j
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %t = index.castu %j : index to i16
+ %val = arith.uitofp %t : i16 to f16
+ memref.store %val, %A[%i, %j] : memref<256x256xf16>
+ }
+ }
+
+ // Initialize the B matrix.
+ // Make matrix B an identity matrix.
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %i_i32 = index.castu %i : index to i32
+ %j_i32 = index.castu %j : index to i32
+ %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+ scf.if %i_j_same {
+ memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+ } else {
+ memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+ }
+ }
+ }
+
+ // Initialize matrix C and C_ref ; C[i, j] = 0
+ %c0_f32 = arith.constant 0.0 : f32
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+ memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+ }
+ }
+
+ // Run GPU version.
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+ %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
+ call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> ()
+ memref.dealloc %A : memref<256x256xf16>
+ memref.dealloc %B : memref<256x256xf16>
+ memref.dealloc %C : memref<256x256xf32>
+ memref.dealloc %C_ref : memref<256x256xf32>
+ return
+ }
+ func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
+ func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
>From 3072c1c14768a8a106fabcce27ba06a732568944 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 30 Sep 2025 15:05:08 +0000
Subject: [PATCH 04/16] Fix a small logic error in XeGPU pass selection.
---
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index ae77143f6a66d..e3e1bab3e12d1 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -70,7 +70,8 @@ void buildGpuPassPipeline(OpPassManager &pm,
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
- if (options.xegpuOpLevel == "subgroup") {
+ if (options.xegpuOpLevel == "subgroup" ||
+ options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
>From 166e348e7197b93466e0e827bcd366698eb664f7 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 30 Sep 2025 21:57:55 +0000
Subject: [PATCH 05/16] Add test cases for SG and WG.
---
.../Dialect/XeGPU/SG/simple_gemm.mlir | 120 ++++++++++++++
.../Dialect/XeGPU/SIMT/simple_gemm.mlir | 2 -
.../Dialect/XeGPU/WG/simple_gemm.mlir | 149 ++++++++++++++++++
3 files changed, 269 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
create mode 100644 mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
diff --git a/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
new file mode 100644
index 0000000000000..877edf47fcd15
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=subgroup" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+ gpu.module @kernel {
+ gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %block_x = gpu.block_id x
+ %block_y = gpu.block_id y
+ %x_block_offset = arith.muli %block_x, %c8 : index
+ %y_block_offset = arith.muli %block_y, %c16 : index
+
+ %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+
+ %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8x16xf32>) {
+ %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %dpas : vector<8x16xf32>
+ }
+ xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.return
+ }
+ }
+
+ func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %memref_a = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16>
+ %memref_b = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16>
+ %memref_c = gpu.alloc () : memref<256x256xf32>
+ gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32>
+ gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
+ gpu.wait // Wait for the kernel to finish.
+ gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32>
+ gpu.dealloc %memref_a : memref<256x256xf16>
+ gpu.dealloc %memref_b : memref<256x256xf16>
+ gpu.dealloc %memref_c : memref<256x256xf32>
+ return %c : memref<256x256xf32>
+ }
+
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1_f16 = arith.constant 1.0 : f16
+ %c2_f16 = arith.constant 2.0 : f16
+ %c256 = arith.constant 256 : index
+ %cf_0 = arith.constant 0.0 : f16
+ %cf_1 = arith.constant 1.0 : f16
+ %A = memref.alloc() : memref<256x256xf16>
+ %B = memref.alloc() : memref<256x256xf16>
+ %C = memref.alloc() : memref<256x256xf32>
+ %C_ref = memref.alloc() : memref<256x256xf32>
+ %c_gen_int = arith.constant 0 : i1
+ %cf_lower = arith.constant -0.5 : f32
+ %cf_upper = arith.constant 0.5 : f32
+ // Initialize matrix A ; A[i, j] = j
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %t = index.castu %j : index to i16
+ %val = arith.uitofp %t : i16 to f16
+ memref.store %val, %A[%i, %j] : memref<256x256xf16>
+ }
+ }
+
+ // Initialize the B matrix
+ // Make matrix B an identity matrix
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %i_i32 = index.castu %i : index to i32
+ %j_i32 = index.castu %j : index to i32
+ %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+ scf.if %i_j_same {
+ memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+ } else {
+ memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+ }
+ }
+ }
+ // Initialize matrix C and C_ref ; C[i, j] = 0
+ %c0_f32 = arith.constant 0.0 : f32
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+ memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+ }
+ }
+
+ // Run GPU.
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+ %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
+ call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
+
+ memref.dealloc %A : memref<256x256xf16>
+ memref.dealloc %B : memref<256x256xf16>
+ memref.dealloc %C : memref<256x256xf32>
+ memref.dealloc %C_ref : memref<256x256xf32>
+ return
+ }
+ func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
index ddae9c1e7eb8f..36b04791ee2dd 100644
--- a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
+++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
@@ -25,7 +25,6 @@ module @gemm attributes {gpu.container_module} {
%b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
%r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
-
%a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
%b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
%dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
@@ -118,6 +117,5 @@ module @gemm attributes {gpu.container_module} {
memref.dealloc %C_ref : memref<256x256xf32>
return
}
- func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
}
diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
new file mode 100644
index 0000000000000..fdc24c02f9d98
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
@@ -0,0 +1,149 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workgroup" \
+// RUN: | mlir-runner \
+// RUN: --shared-libs=%mlir_levelzero_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#a = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+#b = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [16, 16]>
+#c = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>
+#a_prefetch = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32], inst_data = [8, 16]>
+#b_prefetch = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32], inst_data = [8, 16]>
+module @gemm attributes {gpu.container_module} {
+ func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c512 = arith.constant 512 : index
+ %A_gpu = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %A_gpu, %A : memref<256x256xf16>, memref<256x256xf16>
+ %B_gpu = gpu.alloc () : memref<256x256xf16>
+ gpu.memcpy %B_gpu, %B : memref<256x256xf16>, memref<256x256xf16>
+ %C_gpu = gpu.alloc () : memref<256x256xf32>
+ gpu.memcpy %C_gpu, %C : memref<256x256xf32>, memref<256x256xf32>
+ // NOTE: Here we can't use [8, 64] wi threads following the SG thread layout of [8, 4]. Because runtime will linearize the x dimension first (we need y dimension to be linearized first).
+ // So just use linearized thread layout of [512, 1] wi threads.
+ gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>)
+ gpu.wait // Wait for the kernel to finish.
+ gpu.memcpy %C, %C_gpu : memref<256x256xf32>, memref<256x256xf32>
+ gpu.dealloc %A_gpu : memref<256x256xf16>
+ gpu.dealloc %B_gpu : memref<256x256xf16>
+ gpu.dealloc %C_gpu : memref<256x256xf32>
+ return %C : memref<256x256xf32>
+ }
+
+ gpu.module @test_kernel {
+ gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c96 = arith.constant 96 : index
+ %c256 = arith.constant 256 : index
+ %c4096 = arith.constant 4096 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %m = arith.muli %block_id_x, %c256 : index
+ %n = arith.muli %block_id_y, %c256 : index
+ %c_tdesc = xegpu.create_nd_tdesc %C : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #c>
+ %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32>
+ %a_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a>
+ %b_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b>
+ // Prefetch A 3 times.
+ %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+ xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+ xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+ xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+ // Prefetch B 3 times.
+ %b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+ xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+ xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+ xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+
+ %out = scf.for %k = %c0 to %c256 step %c32
+ iter_args(%c_value = %c_init_value)
+ -> (vector<256x256xf32>) {
+ %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16>
+ %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16>
+ // Prefetch next tiles.
+ %prefetch_offset = arith.addi %k, %c96 : index
+ xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+ xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+ %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c}
+ : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32>
+ scf.yield %c_new_value : vector<256x256xf32>
+ }
+ xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c>
+ gpu.return
+ }
+ }
+
+ func.func @main() attributes {llvm.emit_c_interface} {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c1_f16 = arith.constant 1.0 : f16
+ %c2_f16 = arith.constant 2.0 : f16
+ %c256 = arith.constant 256 : index
+ %cf_0 = arith.constant 0.0 : f16
+ %cf_1 = arith.constant 1.0 : f16
+ %A = memref.alloc() : memref<256x256xf16>
+ %B = memref.alloc() : memref<256x256xf16>
+ %C = memref.alloc() : memref<256x256xf32>
+ %C_ref = memref.alloc() : memref<256x256xf32>
+ %c_gen_int = arith.constant 0 : i1
+ %cf_lower = arith.constant -0.5 : f32
+ %cf_upper = arith.constant 0.5 : f32
+ // Initialize matrix A; A[i, j] = j
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %t = index.castu %j : index to i16
+ %val = arith.uitofp %t : i16 to f16
+ memref.store %val, %A[%i, %j] : memref<256x256xf16>
+ }
+ }
+
+ // Initialize the B matrix
+ // Make matrix B an identity matrix
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ %i_i32 = index.castu %i : index to i32
+ %j_i32 = index.castu %j : index to i32
+ %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+ scf.if %i_j_same {
+ memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+ } else {
+ memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+ }
+ }
+ }
+
+ // Initialize matrix C and C_ref ; C[i, j] = 0
+ %c0_f32 = arith.constant 0.0 : f32
+ scf.for %i = %c0 to %c256 step %c1 {
+ scf.for %j = %c0 to %c256 step %c1 {
+ memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+ memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+ }
+ }
+
+ // Run GPU version.
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+ %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+ // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+ // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
+ call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> ()
+
+ memref.dealloc %A : memref<256x256xf16>
+ memref.dealloc %B : memref<256x256xf16>
+ memref.dealloc %C : memref<256x256xf32>
+ memref.dealloc %C_ref : memref<256x256xf32>
+ return
+ }
+ func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}
>From 3ebee991458ab3f21db4607d31b9992172fd3eae Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Fri, 3 Oct 2025 16:32:22 +0000
Subject: [PATCH 06/16] Address review comments.
Change `workitem` to `lane`.
---
mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 4 ++--
mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 67dc6415008f3..c545517ec7739 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -63,11 +63,11 @@ struct GPUToNVVMPipelineOptions
// Options for the gpu to xevm pipeline.
struct GPUToXeVMPipelineOptions
: public PassPipelineOptions<GPUToXeVMPipelineOptions> {
- // XeGPU op granularity selection: workgroup | subgroup | workitem
+ // XeGPU op granularity selection: workgroup | subgroup | lane
PassOptions::Option<std::string> xegpuOpLevel{
*this, "xegpu-op-level",
llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
- "subgroup | workitem"),
+ "subgroup | lane"),
llvm::cl::init("workgroup")};
// General lowering controls.
PassOptions::Option<bool> use64bitIndex{
diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
index 36b04791ee2dd..ffe29ef35a5c9 100644
--- a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
+++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workitem" \
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_levelzero_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
>From bf75132dfea4bf54df22d94c5511c1ea77773c6c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 7 Oct 2025 14:06:46 +0000
Subject: [PATCH 07/16] Address review comments.
Rename `SIMT` test folder to `LANE`.
---
.../Integration/Dialect/XeGPU/{SIMT => LANE}/simple_gemm.mlir | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename mlir/test/Integration/Dialect/XeGPU/{SIMT => LANE}/simple_gemm.mlir (100%)
diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
similarity index 100%
rename from mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir
rename to mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
>From f3ce5e3ffcbc9366fd441554343e7f4e65b5e037 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 8 Oct 2025 22:06:40 +0000
Subject: [PATCH 08/16] Address review comments.
Remove some unnecessary headers.
---
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index e3e1bab3e12d1..3911c6a135df0 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -13,12 +13,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
-#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
-#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
-#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
-#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
-#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
>From 0d4f233ca621544ddd9f1c7e4c4067c6b88a8fc5 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Mon, 13 Oct 2025 18:35:56 +0000
Subject: [PATCH 09/16] Address review comments.
---
mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 5 ++---
mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 1 -
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 10 ++++------
3 files changed, 6 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index c545517ec7739..66d42fc6a1996 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU NVVM/XeVM pipeline entry points----------------------===//
+//===- Passes.h - GPU pipeline entry points----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -63,7 +63,6 @@ struct GPUToNVVMPipelineOptions
// Options for the gpu to xevm pipeline.
struct GPUToXeVMPipelineOptions
: public PassPipelineOptions<GPUToXeVMPipelineOptions> {
- // XeGPU op granularity selection: workgroup | subgroup | lane
PassOptions::Option<std::string> xegpuOpLevel{
*this, "xegpu-op-level",
llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
@@ -123,7 +122,7 @@ void buildLowerToNVVMPassPipeline(OpPassManager &pm,
void buildLowerToXeVMPassPipeline(OpPassManager &pm,
const GPUToXeVMPipelineOptions &options);
-/// Register all pipeleines for the `gpu` dialect.
+/// Register all pipelines for the `gpu` dialect.
void registerGPUToNVVMPipeline();
void registerGPUToXeVMPipeline();
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index f231eebe6d82b..a9d0540d0d504 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -13,7 +13,6 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRLinalgTransforms
MLIRAffineToStandard
MLIRGPUToNVVMTransforms
- MLIRXeGPUToXeVM
MLIRIndexToLLVM
MLIRMathToLLVM
MLIRNVGPUToNVVM
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 3911c6a135df0..b09e91f45447b 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -40,6 +40,7 @@ void buildCommonPassPipeline(
OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
// builtin.module scope passes
pm.addPass(createCSEPass());
+ pm.addPass(createConvertVectorToSCFPass());
{
GpuXeVMAttachTargetOptions xevmTargetOptions;
xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
@@ -49,6 +50,8 @@ void buildCommonPassPipeline(
xevmTargetOptions.cmdOptions = options.cmdOptions;
pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
}
+ pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+ pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
}
//===----------------------------------------------------------------------===//
@@ -59,7 +62,6 @@ void buildGpuPassPipeline(OpPassManager &pm,
if (options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
@@ -90,21 +92,17 @@ void buildGpuPassPipeline(OpPassManager &pm,
//===----------------------------------------------------------------------===//
void buildHostPostPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
- pm.addNestedPass<func::FuncOp>(LLVM::createLLVMRequestCWrappersPass());
- pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
pm.addPass(createReconcileUnrealizedCastsPass());
- pm.addPass(createConvertVectorToSCFPass());
pm.addPass(createSCFToControlFlowPass());
pm.addPass(memref::createExpandStridedMetadataPass());
- pm.addPass(createFinalizeMemRefToLLVMConversionPass());
{
GpuToLLVMConversionPassOptions gpuToLLVMOptions;
gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
}
- pm.addPass(createConvertToLLVMPass());
pm.addPass(createLowerAffinePass());
+ pm.addPass(createConvertToLLVMPass());
pm.addPass(createReconcileUnrealizedCastsPass());
// gpu-module-to-binary
{
>From ea9be96137c0e591de50ec0d2a5b72468f5f5df1 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Mon, 13 Oct 2025 21:07:09 +0000
Subject: [PATCH 10/16] Add a missing dependency. Skip tests if XeVM tests
or the Level Zero runner are not enabled.
---
mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 1 +
mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg | 4 ++++
mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg | 4 ++++
mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg | 4 ++++
4 files changed, 13 insertions(+)
create mode 100644 mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
create mode 100644 mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
create mode 100644 mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index a9d0540d0d504..262248fea89ca 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRTransforms
MLIRLinalgTransforms
MLIRAffineToStandard
+ MLIRGPUToLLVMSPV
MLIRGPUToNVVMTransforms
MLIRIndexToLLVM
MLIRMathToLLVM
diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+ config.unsupported = True
+if not config.enable_levelzero_runner:
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+ config.unsupported = True
+if not config.enable_levelzero_runner:
+ config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+ config.unsupported = True
+if not config.enable_levelzero_runner:
+ config.unsupported = True
>From 14487e39092b9bf0c5363ced2a2a17edddb3ba45 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 14 Oct 2025 01:05:46 +0000
Subject: [PATCH 11/16] Address review comments.
Remove the explicit use of `convert-xevm-to-llvm` pass.
`convert-to-llvm` already applies the XeVM-to-LLVM conversion patterns.
---
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index b09e91f45447b..2928dc3940cb1 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -83,7 +83,6 @@ void buildGpuPassPipeline(OpPassManager &pm,
pm.addNestedPass<gpu::GPUModuleOp>(
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
}
- pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeVMToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
>From b24f09b2fcb9544b7f7f77c63ed539240b42b085 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 14 Oct 2025 15:37:04 +0000
Subject: [PATCH 12/16] Address review comments.
---
mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 2 +-
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 4 ++--
mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir | 4 +++-
3 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 66d42fc6a1996..fccb49d49da70 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU pipeline entry points----------------------===//
+//===- Passes.h - GPU pipeline entry points--------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 2928dc3940cb1..9aa2b3f183c09 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -50,7 +50,7 @@ void buildCommonPassPipeline(
xevmTargetOptions.cmdOptions = options.cmdOptions;
pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
}
- pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
+ pm.addPass(createLowerAffinePass());
pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
}
@@ -84,6 +84,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
}
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+ pm.addPass(createReconcileUnrealizedCastsPass());
}
//===----------------------------------------------------------------------===//
@@ -91,7 +92,6 @@ void buildGpuPassPipeline(OpPassManager &pm,
//===----------------------------------------------------------------------===//
void buildHostPostPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
- pm.addPass(createReconcileUnrealizedCastsPass());
pm.addPass(createSCFToControlFlowPass());
pm.addPass(memref::createExpandStridedMetadataPass());
{
diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
index fdc24c02f9d98..3f2fff9ab51e9 100644
--- a/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
+++ b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
@@ -26,7 +26,9 @@ module @gemm attributes {gpu.container_module} {
gpu.memcpy %B_gpu, %B : memref<256x256xf16>, memref<256x256xf16>
%C_gpu = gpu.alloc () : memref<256x256xf32>
gpu.memcpy %C_gpu, %C : memref<256x256xf32>, memref<256x256xf32>
- // NOTE: Here we can't use [8, 64] wi threads following the SG thread layout of [8, 4]. Because runtime will linearize the x dimension first (we need y dimension to be linearized first).
+ // NOTE: Here we can't use [8, 64] wi threads following
+ // the SG thread layout of [8, 4], because the runtime linearizes
+ // the x dimension first (we need the y dimension linearized first).
// So just use linearized thread layout of [512, 1] wi threads.
gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>)
gpu.wait // Wait for the kernel to finish.
>From 6fe5a127cfe95eb768f076d107f5e89a0ae77912 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 14 Oct 2025 21:09:24 +0000
Subject: [PATCH 13/16] Address review comments.
---
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 9aa2b3f183c09..ba70396c14728 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -84,7 +84,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
}
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
- pm.addPass(createReconcileUnrealizedCastsPass());
+ pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
}
//===----------------------------------------------------------------------===//
>From 846421ba84d0849e3ba2b64646782be4d446a6b7 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 14 Oct 2025 21:21:38 +0000
Subject: [PATCH 14/16] Add `ConvertMathToXeVMPass` to the pipeline.
---
mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 1 +
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 2 ++
2 files changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index 262248fea89ca..ec68acfee7ef1 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRGPUToNVVMTransforms
MLIRIndexToLLVM
MLIRMathToLLVM
+ MLIRMathToXeVM
MLIRNVGPUToNVVM
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index ba70396c14728..a4796888e2867 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -18,6 +18,7 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
@@ -76,6 +77,7 @@ void buildGpuPassPipeline(OpPassManager &pm,
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
}
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
{
ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
>From 55e2f467805dbdf0bfd0ea60e72db2c0505eacc4 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 15 Oct 2025 00:03:47 +0000
Subject: [PATCH 15/16] Address review comments.
Rename the utility functions.
---
.../GPU/Pipelines/GPUToXeVMPipeline.cpp | 25 ++++++++++---------
1 file changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index a4796888e2867..2b996dccc2d9a 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -35,11 +35,11 @@ using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
-// Common pipeline
+// Pre-GPU common pipeline for both Host and GPU.
//===----------------------------------------------------------------------===//
-void buildCommonPassPipeline(
+void buildPreGPUCommonPassPipeline(
OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
- // builtin.module scope passes
+ // builtin.module scope passes.
pm.addPass(createCSEPass());
pm.addPass(createConvertVectorToSCFPass());
{
@@ -58,7 +58,7 @@ void buildCommonPassPipeline(
//===----------------------------------------------------------------------===//
// GPUModule-specific stuff.
//===----------------------------------------------------------------------===//
-void buildGpuPassPipeline(OpPassManager &pm,
+void buildGPUPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
if (options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
@@ -90,10 +90,11 @@ void buildGpuPassPipeline(OpPassManager &pm,
}
//===----------------------------------------------------------------------===//
-// Host Post-GPU pipeline
+// Post-GPU pipeline for both Host and GPU.
//===----------------------------------------------------------------------===//
-void buildHostPostPipeline(OpPassManager &pm,
+void buildPostGPUCommonPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+ // builtin.module scope passes.
pm.addPass(createSCFToControlFlowPass());
pm.addPass(memref::createExpandStridedMetadataPass());
{
@@ -117,14 +118,14 @@ void buildHostPostPipeline(OpPassManager &pm,
void mlir::gpu::buildLowerToXeVMPassPipeline(
OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
- // Common pipelines
- buildCommonPassPipeline(pm, options);
+ // Pre-GPU common pipelines.
+ buildPreGPUCommonPassPipeline(pm, options);
- // GPUModule-specific stuff
- buildGpuPassPipeline(pm, options);
+ // GPUModule-specific stuff.
+ buildGPUPassPipeline(pm, options);
- // Host post-GPUModule-specific stuff
- buildHostPostPipeline(pm, options);
+ // Post-GPU pipeline for both Host and GPU.
+ buildPostGPUCommonPassPipeline(pm, options);
}
void mlir::gpu::registerGPUToXeVMPipeline() {
>From be996afcfb480008168f8f0fac79f95a8bd5d2e8 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 15 Oct 2025 00:57:43 +0000
Subject: [PATCH 16/16] Fix formatting.
---
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 2b996dccc2d9a..1a1485ba2e02c 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -13,12 +13,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
@@ -92,8 +92,8 @@ void buildGPUPassPipeline(OpPassManager &pm,
//===----------------------------------------------------------------------===//
// Post-GPU pipeline for both Host and GPU.
//===----------------------------------------------------------------------===//
-void buildPostGPUCommonPassPipeline(OpPassManager &pm,
- const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+void buildPostGPUCommonPassPipeline(
+ OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
// builtin.module scope passes.
pm.addPass(createSCFToControlFlowPass());
pm.addPass(memref::createExpandStridedMetadataPass());
More information about the Mlir-commits
mailing list