[Mlir-commits] [mlir] 89d1143 - [mlir][gpu]Add GPUToXeVM lowering pipeline pass. (#161216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 17 12:41:24 PDT 2025


Author: Md Abdullah Shahneous Bari
Date: 2025-10-17T14:41:20-05:00
New Revision: 89d1143a90b7bf9b9ebabab1d8b7ed519c6ec5eb

URL: https://github.com/llvm/llvm-project/commit/89d1143a90b7bf9b9ebabab1d8b7ed519c6ec5eb
DIFF: https://github.com/llvm/llvm-project/commit/89d1143a90b7bf9b9ebabab1d8b7ed519c6ec5eb.diff

LOG: [mlir][gpu]Add GPUToXeVM lowering pipeline pass. (#161216)

This is the default GPU to XeVM lowering pipeline. It first lowers the
GPU code to the specified compilation target (fatbin by default), then
lowers the host code.
If XeGPU ops are used, they are expected to already be embedded in the
gpu code of the input MLIR.
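
For context, the pipeline is exposed to mlir-opt as
--gpu-lower-to-xevm-pipeline (see the RUN lines in the added tests). The
sketch below is illustrative only and not part of this commit: it shows how
the pipeline could be assembled programmatically through the entry point
added to Passes.h. The helper name buildMyXeVMLowering is hypothetical, and
the option values are examples rather than defaults.

  // Illustrative sketch (not part of this commit): drive the pipeline from
  // C++ via the entry point declared in Passes.h.
  #include "mlir/Dialect/GPU/Pipelines/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Hypothetical helper: appends the GPU to XeVM lowering to a pass manager.
  static void buildMyXeVMLowering(mlir::OpPassManager &pm) {
    mlir::gpu::GPUToXeVMPipelineOptions options;
    options.xegpuOpLevel = "subgroup"; // workgroup | subgroup | lane
    options.binaryFormat = "fatbin";   // final GPU binary emission format
    options.zebinChip = "pvc";         // target chip, e.g. pvc or bmg
    mlir::gpu::buildLowerToXeVMPassPipeline(pm, options);
  }

The textual equivalent used by the tests is, e.g.,
mlir-opt input.mlir --gpu-lower-to-xevm-pipeline="xegpu-op-level=subgroup".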

Added: 
    mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
    mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
    mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
    mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
    mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
    mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
    mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
    mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
    mlir/lib/RegisterAllPasses.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index 035235fc7174a..fccb49d49da70 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - GPU NVVM pipeline entry points --------------------------===//
+//===- Passes.h - GPU pipeline entry points -------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -60,6 +60,52 @@ struct GPUToNVVMPipelineOptions
       llvm::cl::init(false)};
 };
 
+// Options for the GPU to XeVM pipeline.
+struct GPUToXeVMPipelineOptions
+    : public PassPipelineOptions<GPUToXeVMPipelineOptions> {
+  PassOptions::Option<std::string> xegpuOpLevel{
+      *this, "xegpu-op-level",
+      llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
+                     "subgroup | lane"),
+      llvm::cl::init("workgroup")};
+  // General lowering controls.
+  PassOptions::Option<bool> use64bitIndex{
+      *this, "use-64bit-index",
+      llvm::cl::desc("Bitwidth of the index type (host & device)"),
+      llvm::cl::init(true)};
+  PassOptions::Option<bool> kernelBarePtrCallConv{
+      *this, "kernel-bare-ptr-calling-convention",
+      llvm::cl::desc("Use bare pointer calling convention for device kernels"),
+      llvm::cl::init(false)};
+  PassOptions::Option<bool> hostBarePtrCallConv{
+      *this, "host-bare-ptr-calling-convention",
+      llvm::cl::desc("Use bare pointer calling convention for host launches"),
+      llvm::cl::init(false)};
+  PassOptions::Option<std::string> binaryFormat{
+      *this, "binary-format",
+      llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"),
+      llvm::cl::init("fatbin")};
+  // Options mirroring xevm-attach-target (GpuXeVMAttachTarget).
+  PassOptions::Option<std::string> xevmModuleMatcher{
+      *this, "xevm-module-matcher",
+      llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"),
+      llvm::cl::init("")};
+  PassOptions::Option<std::string> zebinTriple{
+      *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"),
+      llvm::cl::init("spirv64-unknown-unknown")};
+  PassOptions::Option<std::string> zebinChip{
+      *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"),
+      llvm::cl::init("bmg")};
+  PassOptions::Option<unsigned> optLevel{
+      *this, "opt-level",
+      llvm::cl::desc("Optimization level for attached target/codegen"),
+      llvm::cl::init(2)};
+  PassOptions::Option<std::string> cmdOptions{
+      *this, "igc-cmd-options",
+      llvm::cl::desc("Additional downstream compiler command line options"),
+      llvm::cl::init("")};
+};
+
 //===----------------------------------------------------------------------===//
 // Building and Registering.
 //===----------------------------------------------------------------------===//
@@ -70,8 +116,15 @@ struct GPUToNVVMPipelineOptions
 void buildLowerToNVVMPassPipeline(OpPassManager &pm,
                                   const GPUToNVVMPipelineOptions &options);
 
-/// Register all pipeleines for the `gpu` dialect.
+/// Adds the GPU to XeVM lowering pipeline to the given pass manager. It
+/// lowers the main dialects to the XeVM target, handling the GPU code regions
+/// first and then the host code.
+void buildLowerToXeVMPassPipeline(OpPassManager &pm,
+                                  const GPUToXeVMPipelineOptions &options);
+
+/// Register all pipelines for the `gpu` dialect.
 void registerGPUToNVVMPipeline();
+void registerGPUToXeVMPipeline();
 
 } // namespace gpu
 } // namespace mlir

diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
index 70a9c77a6d796..ec68acfee7ef1 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRGPUPipelines
   GPUToNVVMPipeline.cpp
+  GPUToXeVMPipeline.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -11,12 +12,17 @@ add_mlir_dialect_library(MLIRGPUPipelines
   MLIRTransforms
   MLIRLinalgTransforms
   MLIRAffineToStandard
+  MLIRGPUToLLVMSPV
   MLIRGPUToNVVMTransforms
   MLIRIndexToLLVM
   MLIRMathToLLVM
+  MLIRMathToXeVM
   MLIRNVGPUToNVVM
   MLIRNVVMToLLVM
   MLIRReconcileUnrealizedCasts
   MLIRSCFToControlFlow
   MLIRVectorToSCF
+  MLIRXeGPUTransforms
+  MLIRXeGPUToXeVM
+  MLIRXeVMToLLVM
 )

diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
new file mode 100644
index 0000000000000..1a1485ba2e02c
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -0,0 +1,139 @@
+//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the GPU to XeVM lowering pipeline as a generally
+// usable sink pipeline. If XeGPU ops are used, it expects the MLIR code to
+// have XeGPU ops already embedded in gpu code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Pre-GPU common pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPreGPUCommonPassPipeline(
+    OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+  // builtin.module scope passes.
+  pm.addPass(createCSEPass());
+  pm.addPass(createConvertVectorToSCFPass());
+  {
+    GpuXeVMAttachTargetOptions xevmTargetOptions;
+    xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
+    xevmTargetOptions.triple = options.zebinTriple;
+    xevmTargetOptions.chip = options.zebinChip;
+    xevmTargetOptions.optLevel = options.optLevel;
+    xevmTargetOptions.cmdOptions = options.cmdOptions;
+    pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
+  }
+  pm.addPass(createLowerAffinePass());
+  pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
+}
+
+//===----------------------------------------------------------------------===//
+// GPUModule-specific stuff.
+//===----------------------------------------------------------------------===//
+void buildGPUPassPipeline(OpPassManager &pm,
+                          const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+  if (options.xegpuOpLevel == "workgroup") {
+    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+  }
+  if (options.xegpuOpLevel == "subgroup" ||
+      options.xegpuOpLevel == "workgroup") {
+    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
+  }
+  pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToXeVM());
+  pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
+  {
+    ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
+    gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
+    pm.addNestedPass<gpu::GPUModuleOp>(
+        createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
+  }
+  pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+  pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
+}
+
+//===----------------------------------------------------------------------===//
+// Post-GPU common pipeline for both Host and GPU.
+//===----------------------------------------------------------------------===//
+void buildPostGPUCommonPassPipeline(
+    OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
+  // builtin.module scope passes.
+  pm.addPass(createSCFToControlFlowPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  {
+    GpuToLLVMConversionPassOptions gpuToLLVMOptions;
+    gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
+    gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
+    pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
+  }
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createConvertToLLVMPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+  // gpu-module-to-binary
+  {
+    GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
+    gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
+    gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
+    pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
+  }
+}
+} // namespace
+
+void mlir::gpu::buildLowerToXeVMPassPipeline(
+    OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
+  // Pre-GPU common pipelines.
+  buildPreGPUCommonPassPipeline(pm, options);
+
+  // GPUModule-specific stuff.
+  buildGPUPassPipeline(pm, options);
+
+  // Post-GPU pipeline for both Host and GPU.
+  buildPostGPUCommonPassPipeline(pm, options);
+}
+
+void mlir::gpu::registerGPUToXeVMPipeline() {
+  PassPipelineRegistration<GPUToXeVMPipelineOptions>(
+      "gpu-lower-to-xevm-pipeline",
+      "The default GPU to XeVM lowering pipeline. It starts by lowering GPU "
+      "code to the "
+      "specified compilation target (default is fatbin) then lowers the host "
+      "code.",
+      buildLowerToXeVMPassPipeline);
+}

diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index c67b24226ae45..dd413d2de8710 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -98,4 +98,5 @@ void mlir::registerAllPasses() {
   sparse_tensor::registerSparseTensorPipelines();
   tosa::registerTosaToLinalgPipelines();
   gpu::registerGPUToNVVMPipeline();
+  gpu::registerGPUToXeVMPipeline();
 }

diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/LANE/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+    config.unsupported = True
+if not config.enable_levelzero_runner:
+    config.unsupported = True

diff --git a/mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
new file mode 100644
index 0000000000000..ffe29ef35a5c9
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
@@ -0,0 +1,121 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_levelzero_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+  gpu.module @kernel {
+    gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c8 = arith.constant 8 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c256 = arith.constant 256 : index
+      %block_x = gpu.block_id x
+      %block_y = gpu.block_id y
+      %x_block_offset = arith.muli %block_x, %c8 : index
+      %y_block_offset = arith.muli %block_y, %c16 : index
+
+      %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+      %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16>
+      %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+
+      %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
+        %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+        %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+        %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+        scf.yield %dpas : vector<8xf32>
+      }
+      xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+      gpu.return
+    }
+  }
+
+  func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %memref_a = gpu.alloc  () : memref<256x256xf16>
+    gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16>
+    %memref_b = gpu.alloc  () : memref<256x256xf16>
+    gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16>
+    %memref_c = gpu.alloc  () : memref<256x256xf32>
+    gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32>
+    gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
+    gpu.wait // Wait for the kernel to finish.
+    gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32>
+    gpu.dealloc %memref_a : memref<256x256xf16>
+    gpu.dealloc %memref_b : memref<256x256xf16>
+    gpu.dealloc %memref_c : memref<256x256xf32>
+    return %c : memref<256x256xf32>
+  }
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1_f16 = arith.constant 1.0 : f16
+    %c2_f16 = arith.constant 2.0 : f16
+    %c256 = arith.constant 256 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<256x256xf16>
+    %B = memref.alloc() : memref<256x256xf16>
+    %C = memref.alloc() : memref<256x256xf32>
+    %C_ref = memref.alloc() : memref<256x256xf32>
+    %c_gen_int = arith.constant 0 : i1
+    %cf_lower = arith.constant -0.5 : f32
+    %cf_upper = arith.constant 0.5 : f32
+
+    // Initialize matrix A ; A[i, j] = j
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<256x256xf16>
+      }
+    }
+
+    // Initialize the B matrix.
+    // Make matrix B an identity matrix.
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+        }
+      }
+    }
+
+    // Initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+      }
+    }
+
+    // Run GPU version.
+    %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+    %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+
+    // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+    // CHECK-COUNT-256: [0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,   50,   51,   52,   53,   54,   55,   56,   57,   58,   59,   60,   61,   62,   63,   64,   65,   66,   67,   68,   69,   70,   71,   72,   73,   74,   75,   76,   77,   78,   79,   80,   81,   82,   83,   84,   85,   86,   87,   88,   89,   90,   91,   92,   93,   94,   95,   96,   97,   98,   99,   100,   101,   102,   103,   104,   105,   106,   107,   108,   109,   110,   111,   112,   113,   114,   115,   116,   117,   118,   119,   120,   121,   122,   123,   124,   125,   126,   127,   128,   129,   130,   131,   132,   133,   134,   135,   136,   137,   138,   139,   140,   141,   142,   143,   144,   145,   146,   147,   148,   149,   150,   151,   152,   153,   154,   155,   156,   157,   158,   159,   160,   161,   162,   163,   164,   165,   166,   167,   168,   169,   170,   171,   172,   173,   174,   175,   176,   177,   178,   179,   180,   181,   182,   183,   184,   185,   186,   187,   188,   189,   190,   191,   192,   193,   194,   195,   196,   197,   198,   199,   200,   201,   202,   203,   204,   205,   206,   207,   208,   209,   210,   211,   212,   213,   214,   215,   216,   217,   218,   219,   220,   221,   222,   223,   224,   225,   226,   227,   228,   229,   230,   231,   232,   233,   234,   235,   236,   237,   238,   239,   240,   241,   242,   243,   244,   245,   246,   247,   248,   249,   250,   251,   252,   253,   254,   255]
+    call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> ()
+    memref.dealloc %A : memref<256x256xf16>
+    memref.dealloc %B : memref<256x256xf16>
+    memref.dealloc %C : memref<256x256xf32>
+    memref.dealloc %C_ref : memref<256x256xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}

diff --git a/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/SG/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+    config.unsupported = True
+if not config.enable_levelzero_runner:
+    config.unsupported = True

diff --git a/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
new file mode 100644
index 0000000000000..877edf47fcd15
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir
@@ -0,0 +1,120 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=subgroup" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_levelzero_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @gemm attributes {gpu.container_module} {
+  gpu.module @kernel {
+    gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c8 = arith.constant 8 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c256 = arith.constant 256 : index
+      %block_x = gpu.block_id x
+      %block_y = gpu.block_id y
+      %x_block_offset = arith.muli %block_x, %c8 : index
+      %y_block_offset = arith.muli %block_y, %c16 : index
+
+      %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+      %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16>
+      %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+
+      %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8x16xf32>) {
+        %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k]  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+        %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+        %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+        scf.yield %dpas : vector<8x16xf32>
+      }
+      xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      gpu.return
+    }
+  }
+
+  func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %memref_a = gpu.alloc  () : memref<256x256xf16>
+    gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16>
+    %memref_b = gpu.alloc () : memref<256x256xf16>
+    gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16>
+    %memref_c = gpu.alloc  () : memref<256x256xf32>
+    gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32>
+    gpu.launch_func  @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
+    gpu.wait  // Wait for the kernel to finish.
+    gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32>
+    gpu.dealloc %memref_a : memref<256x256xf16>
+    gpu.dealloc %memref_b : memref<256x256xf16>
+    gpu.dealloc %memref_c : memref<256x256xf32>
+    return %c : memref<256x256xf32>
+  }
+
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1_f16 = arith.constant 1.0 : f16
+    %c2_f16 = arith.constant 2.0 : f16
+    %c256 = arith.constant 256 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<256x256xf16>
+    %B = memref.alloc() : memref<256x256xf16>
+    %C = memref.alloc() : memref<256x256xf32>
+    %C_ref = memref.alloc() : memref<256x256xf32>
+    %c_gen_int = arith.constant 0 : i1
+    %cf_lower = arith.constant -0.5 : f32
+    %cf_upper = arith.constant 0.5 : f32
+    // Initialize matrix A ; A[i, j] = j
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<256x256xf16>
+      }
+    }
+
+    // Initialize the B matrix
+    // Make matrix B an identity matrix
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+        }
+      }
+    }
+    // Initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+      }
+    }
+
+    // Run GPU.
+    %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+    %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+    // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+    // CHECK-COUNT-256: [0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,   50,   51,   52,   53,   54,   55,   56,   57,   58,   59,   60,   61,   62,   63,   64,   65,   66,   67,   68,   69,   70,   71,   72,   73,   74,   75,   76,   77,   78,   79,   80,   81,   82,   83,   84,   85,   86,   87,   88,   89,   90,   91,   92,   93,   94,   95,   96,   97,   98,   99,   100,   101,   102,   103,   104,   105,   106,   107,   108,   109,   110,   111,   112,   113,   114,   115,   116,   117,   118,   119,   120,   121,   122,   123,   124,   125,   126,   127,   128,   129,   130,   131,   132,   133,   134,   135,   136,   137,   138,   139,   140,   141,   142,   143,   144,   145,   146,   147,   148,   149,   150,   151,   152,   153,   154,   155,   156,   157,   158,   159,   160,   161,   162,   163,   164,   165,   166,   167,   168,   169,   170,   171,   172,   173,   174,   175,   176,   177,   178,   179,   180,   181,   182,   183,   184,   185,   186,   187,   188,   189,   190,   191,   192,   193,   194,   195,   196,   197,   198,   199,   200,   201,   202,   203,   204,   205,   206,   207,   208,   209,   210,   211,   212,   213,   214,   215,   216,   217,   218,   219,   220,   221,   222,   223,   224,   225,   226,   227,   228,   229,   230,   231,   232,   233,   234,   235,   236,   237,   238,   239,   240,   241,   242,   243,   244,   245,   246,   247,   248,   249,   250,   251,   252,   253,   254,   255]
+    call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
+
+    memref.dealloc %A : memref<256x256xf16>
+    memref.dealloc %B : memref<256x256xf16>
+    memref.dealloc %C : memref<256x256xf32>
+    memref.dealloc %C_ref : memref<256x256xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}

diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg b/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
new file mode 100644
index 0000000000000..d0d51c6020588
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/WG/lit.local.cfg
@@ -0,0 +1,4 @@
+if not config.run_xevm_tests:
+    config.unsupported = True
+if not config.enable_levelzero_runner:
+    config.unsupported = True

diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
new file mode 100644
index 0000000000000..3f2fff9ab51e9
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir
@@ -0,0 +1,151 @@
+// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workgroup" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_levelzero_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#a = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+#b = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [16, 16]>
+#c = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>
+#a_prefetch = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32], inst_data = [8, 16]>
+#b_prefetch = #xegpu.layout<sg_layout = [4, 8], sg_data = [8, 32], inst_data = [8, 16]>
+module @gemm attributes {gpu.container_module} {
+  func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c512 = arith.constant 512 : index
+    %A_gpu = gpu.alloc () : memref<256x256xf16>
+    gpu.memcpy %A_gpu, %A : memref<256x256xf16>, memref<256x256xf16>
+    %B_gpu = gpu.alloc () : memref<256x256xf16>
+    gpu.memcpy %B_gpu, %B : memref<256x256xf16>, memref<256x256xf16>
+    %C_gpu = gpu.alloc () : memref<256x256xf32>
+    gpu.memcpy %C_gpu, %C : memref<256x256xf32>, memref<256x256xf32>
+    // NOTE: Here we can't use a [8, 64] work-item thread layout that follows
+    // the SG thread layout of [8, 4], because the runtime linearizes the x
+    // dimension first (we need the y dimension to be linearized first).
+    // So we just use a linearized thread layout of [512, 1] work-items.
+    gpu.launch_func  @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>)
+    gpu.wait // Wait for the kernel to finish.
+    gpu.memcpy %C, %C_gpu : memref<256x256xf32>, memref<256x256xf32>
+    gpu.dealloc %A_gpu : memref<256x256xf16>
+    gpu.dealloc %B_gpu : memref<256x256xf16>
+    gpu.dealloc %C_gpu : memref<256x256xf32>
+    return %C : memref<256x256xf32>
+  }
+
+  gpu.module @test_kernel   {
+    gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel  {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c32 = arith.constant 32 : index
+      %c64 = arith.constant 64 : index
+      %c96 = arith.constant 96 : index
+      %c256 = arith.constant 256 : index
+      %c4096 = arith.constant 4096 : index
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %m = arith.muli %block_id_x, %c256 : index
+      %n = arith.muli %block_id_y, %c256 : index
+      %c_tdesc = xegpu.create_nd_tdesc %C : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #c>
+      %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32>
+      %a_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a>
+      %b_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b>
+      // Prefetch A 3 times.
+      %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+      xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+      xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+      xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+      // Prefetch B 3 times.
+      %b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+      xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+      xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+      xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+
+      %out = scf.for %k = %c0 to %c256 step %c32
+        iter_args(%c_value = %c_init_value)
+        -> (vector<256x256xf32>) {
+        %a_value = xegpu.load_nd %a_tdesc[%m, %k]  : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16>
+        %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16>
+        // Prefetch next tiles.
+        %prefetch_offset = arith.addi %k, %c96 : index
+        xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch>
+        xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch>
+        %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c}
+          : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32>
+        scf.yield %c_new_value : vector<256x256xf32>
+      }
+      xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c>
+      gpu.return
+    }
+  }
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1_f16 = arith.constant 1.0 : f16
+    %c2_f16 = arith.constant 2.0 : f16
+    %c256 = arith.constant 256 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<256x256xf16>
+    %B = memref.alloc() : memref<256x256xf16>
+    %C = memref.alloc() : memref<256x256xf32>
+    %C_ref = memref.alloc() : memref<256x256xf32>
+    %c_gen_int = arith.constant 0 : i1
+    %cf_lower = arith.constant -0.5 : f32
+    %cf_upper = arith.constant 0.5 : f32
+    // Initialize matrix A ; A[i, j] = j
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<256x256xf16>
+      }
+    }
+
+    // Initialize the B matrix
+    // Make matrix B an identity matrix
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+        }
+      }
+    }
+
+    // Initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+      }
+    }
+
+    // Run GPU version.
+    %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+    %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+    // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
+    // CHECK-COUNT-256: [0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,   11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,   48,   49,   50,   51,   52,   53,   54,   55,   56,   57,   58,   59,   60,   61,   62,   63,   64,   65,   66,   67,   68,   69,   70,   71,   72,   73,   74,   75,   76,   77,   78,   79,   80,   81,   82,   83,   84,   85,   86,   87,   88,   89,   90,   91,   92,   93,   94,   95,   96,   97,   98,   99,   100,   101,   102,   103,   104,   105,   106,   107,   108,   109,   110,   111,   112,   113,   114,   115,   116,   117,   118,   119,   120,   121,   122,   123,   124,   125,   126,   127,   128,   129,   130,   131,   132,   133,   134,   135,   136,   137,   138,   139,   140,   141,   142,   143,   144,   145,   146,   147,   148,   149,   150,   151,   152,   153,   154,   155,   156,   157,   158,   159,   160,   161,   162,   163,   164,   165,   166,   167,   168,   169,   170,   171,   172,   173,   174,   175,   176,   177,   178,   179,   180,   181,   182,   183,   184,   185,   186,   187,   188,   189,   190,   191,   192,   193,   194,   195,   196,   197,   198,   199,   200,   201,   202,   203,   204,   205,   206,   207,   208,   209,   210,   211,   212,   213,   214,   215,   216,   217,   218,   219,   220,   221,   222,   223,   224,   225,   226,   227,   228,   229,   230,   231,   232,   233,   234,   235,   236,   237,   238,   239,   240,   241,   242,   243,   244,   245,   246,   247,   248,   249,   250,   251,   252,   253,   254,   255]
+    call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> ()
+
+    memref.dealloc %A : memref<256x256xf16>
+    memref.dealloc %B : memref<256x256xf16>
+    memref.dealloc %C : memref<256x256xf32>
+    memref.dealloc %C_ref : memref<256x256xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+}

More information about the Mlir-commits mailing list