[Mlir-commits] [mlir] [MLIR][Conversion] Add convert-xevm-to-llvm pass. (PR #147375)

Sang Ik Lee llvmlistbot at llvm.org
Thu Jul 10 07:29:25 PDT 2025


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/147375

>From d9ff9e3a45d66dcc49b72223ae44a23b07f831a9 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 7 Jul 2025 18:49:14 +0000
Subject: [PATCH 1/5] Add convert-xevm-to-llvm pass.

Co-authored-by: Artem Kroviakov artem.kroviakov at intel.com
---
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |   9 +
 .../mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h   |  29 +
 mlir/include/mlir/InitAllExtensions.h         |   2 +
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt |  21 +
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 669 ++++++++++++++++++
 .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir   |  83 +++
 mlir/test/lib/Dialect/GPU/CMakeLists.txt      |   1 +
 9 files changed, 816 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
 create mode 100644 mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
 create mode 100644 mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index c9d2a54433736..8a5976e547169 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -80,6 +80,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..929c3cf5a25ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1495,4 +1495,13 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// XeVMToLLVM
+//===----------------------------------------------------------------------===//
+
+// Lowers XeVM dialect operations to LLVM dialect operations. The pass creates
+// LLVM dialect ops, so the LLVM dialect is declared as a dependency here as
+// well as the XeVM dialect.
+def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
+  let summary = "Convert XeVM to LLVM dialect";
+  let dependentDialects = ["LLVM::LLVMDialect", "xevm::XeVMDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
new file mode 100644
index 0000000000000..b361af573afff
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -0,0 +1,29 @@
+//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+} // namespace mlir
+
+namespace mlir {
+// Pulls in the generated declarations for the convert-xevm-to-llvm pass.
+#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Adds the XeVM-to-LLVM conversion patterns to `patterns`.
+void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+
+/// Registers a dialect extension on `registry` that attaches the
+/// convert-to-LLVM pattern interface to the XeVM dialect.
+void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..47e2f554abb69 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
@@ -90,6 +91,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
+  registerConvertXeVMToLLVMInterface(registry);
 
   // Register all transform dialect extensions.
   affine::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e4b4974600577..24a48993ad80c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -73,3 +73,4 @@ add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
 add_subdirectory(VectorToXeGPU)
+add_subdirectory(XeVMToLLVM)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
new file mode 100644
index 0000000000000..4ac60d8d43472
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library implementing the XeVM -> LLVM dialect lowering
+# (XeVMToLLVM.cpp). DEPENDS on the generated conversion pass declarations.
+add_mlir_conversion_library(MLIRXeVMToLLVM
+  XeVMToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRXeVMDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
new file mode 100644
index 0000000000000..89407825fd656
--- /dev/null
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -0,0 +1,669 @@
+//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "xevm-to-llvm"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace xevm;
+
+namespace {
+
+/// Function attributes applied to the declaration of a device builtin by
+/// createDeviceFunctionCall. The aggregate is initialised positionally below,
+/// so the field order matters: convergent, nounwind, willreturn, memory
+/// effects.
+struct LLVMFuncAttributeOptions {
+  bool isConvergent = false;
+  bool isNoUnwind = false;
+  bool isWillReturn = false;
+  // Left null by the presets below; only applied to the function when set.
+  LLVM::MemoryEffectsAttr memEffectsAttr{};
+};
+// Preset: nounwind only.
+static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
+    false, true, false, {}};
+// Preset: nounwind + willreturn.
+static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
+    false, true, true, {}};
+// Preset: convergent + nounwind + willreturn.
+static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
+    true, true, true, {}};
+
+/// Returns the mangled spelling of `ty` used when building device builtin
+/// names: "Dv<N>_<elem>" for vectors, "Dh"/"f"/"d" for f16/f32/f64, and a
+/// one-letter code for 8/16/32/64-bit integers. `isUnsigned` selects the
+/// unsigned integer letter and is propagated to vector element types. Any
+/// other integer width reaches llvm_unreachable.
+std::string getTypeMangling(Type ty, bool isUnsigned = false) {
+  return TypeSwitch<Type, std::string>(ty)
+      .Case([isUnsigned](VectorType vecTy) -> std::string {
+        // Mangle the element type first, then prepend the vector prefix.
+        std::string elemMangling =
+            getTypeMangling(vecTy.getElementType(), isUnsigned);
+        return "Dv" + std::to_string(vecTy.getNumElements()) + "_" +
+               elemMangling;
+      })
+      .Case([](Float16Type) -> std::string { return "Dh"; })
+      .Case([](Float32Type) -> std::string { return "f"; })
+      .Case([](Float64Type) -> std::string { return "d"; })
+      .Case([isUnsigned](IntegerType intTy) -> std::string {
+        const unsigned width = intTy.getWidth();
+        if (width == 8)
+          return isUnsigned ? "h" : "c";
+        if (width == 16)
+          return isUnsigned ? "t" : "s";
+        if (width == 32)
+          return isUnsigned ? "j" : "i";
+        if (width == 64)
+          return isUnsigned ? "m" : "l";
+        llvm_unreachable("unhandled integer type");
+      });
+}
+
+/// Builds a mangled function name of the form `_Z<len><baseName><args...>`.
+///
+/// Each entry of `types` is mangled with getTypeMangling. A type that is not
+/// an int or float scalar is remembered the first time it is seen; later
+/// occurrences are emitted as a back-reference (`S_`, `S0_`, `S1_`, ...)
+/// instead of being spelled out again. `isUnsigned`, when non-empty, must have
+/// one entry per type and selects the unsigned integer spelling.
+std::string mangle(StringRef baseName, ArrayRef<Type> types,
+                   ArrayRef<bool> isUnsigned = {}) {
+  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
+         "Signedness info doesn't match");
+  std::string mangled;
+  llvm::raw_string_ostream stream(mangled);
+  llvm::SmallDenseMap<Type, unsigned> seenTypes;
+  stream << "_Z" << baseName.size() << baseName;
+  for (auto [argIdx, argTy] : llvm::enumerate(types)) {
+    if (auto seen = seenTypes.find(argTy); seen != seenTypes.end()) {
+      // First substitution is `S_`, second is `S0_`, and so on.
+      stream << "S";
+      if (unsigned ordinal = seen->getSecond(); ordinal > 0)
+        stream << ordinal - 1;
+      stream << "_";
+      continue;
+    }
+    if (!argTy.isIntOrFloat()) {
+      // Read the size before inserting so the new entry gets the next index.
+      const unsigned nextOrdinal = seenTypes.size();
+      seenTypes[argTy] = nextOrdinal;
+    }
+    const bool argIsUnsigned = !isUnsigned.empty() && isUnsigned[argIdx];
+    stream << getTypeMangling(argTy, argIsUnsigned);
+  }
+  return stream.str();
+}
+
+/// Maps the cache-control attribute of `op` to the integer code used for the
+/// L1 entry of the cache-control decoration. `isLoad` selects between the
+/// LoadCacheControl and StoreCacheControl enums. Returns 0 for any enumerator
+/// not listed below.
+///
+/// The attribute is dereferenced unconditionally, so the caller must have
+/// checked that `op.getCacheControl()` holds a value.
+template <bool isLoad, typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3C:
+      return 1;
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1C_L2C_L3C:
+      return 2;
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3UC:
+    case LoadCacheControl::L1S_L2C_L3C:
+      return 3;
+    case LoadCacheControl::INVALIDATE_READ:
+      return 4;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+      return 1;
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+      return 2;
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+      return 3;
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      return 4;
+    default:
+      break;
+    }
+  }
+  // Enumerator without an explicit mapping.
+  return 0;
+}
+
+/// Maps the cache-control attribute of `op` to the integer code used for the
+/// L3 entry of the cache-control decoration. `isLoad` selects between the
+/// LoadCacheControl and StoreCacheControl enums. Enumerators that map to the
+/// same code are grouped; any enumerator not listed yields 0.
+///
+/// The attribute is dereferenced unconditionally, so the caller must have
+/// checked that `op.getCacheControl()` holds a value.
+template <bool isLoad, typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  if constexpr (isLoad) {
+    switch (*op.getCacheControl()) {
+    case LoadCacheControl::L1UC_L2UC_L3UC:
+    case LoadCacheControl::L1UC_L2C_L3UC:
+    case LoadCacheControl::L1C_L2UC_L3UC:
+    case LoadCacheControl::L1C_L2C_L3UC:
+    case LoadCacheControl::L1S_L2UC_L3UC:
+    case LoadCacheControl::L1S_L2C_L3UC:
+      return 1;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3C:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3C:
+    case LoadCacheControl::L1S_L2UC_L3C:
+    case LoadCacheControl::L1S_L2C_L3C:
+      return 2;
+    case LoadCacheControl::INVALIDATE_READ:
+      return 4;
+    default:
+      break;
+    }
+  } else {
+    switch (*op.getCacheControl()) {
+    case StoreCacheControl::L1UC_L2UC_L3UC:
+    case StoreCacheControl::L1UC_L2WB_L3UC:
+    case StoreCacheControl::L1WT_L2UC_L3UC:
+    case StoreCacheControl::L1WT_L2WB_L3UC:
+    case StoreCacheControl::L1S_L2UC_L3UC:
+    case StoreCacheControl::L1S_L2WB_L3UC:
+    case StoreCacheControl::L1WB_L2UC_L3UC:
+    case StoreCacheControl::L1WB_L2WB_L3UC:
+      return 1;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3WB:
+    case StoreCacheControl::L1WB_L2UC_L3WB:
+      return 2;
+    default:
+      break;
+    }
+  }
+  // Enumerator without an explicit mapping.
+  return 0;
+}
+
+/// Builds the cache-control decoration attribute for `op`.
+///
+/// Returns std::nullopt when the op carries no cache-control attribute.
+/// Otherwise returns an array of two i32 arrays, one for L1 and one for L3,
+/// each laid out as {controlKey, level index (0 for L1, 1 for L3), control
+/// code, 0}. `isLoad` selects the load key (6442) or the store key (6443) and
+/// the matching enum in getL1CacheControl / getL3CacheControl.
+template <bool isLoad, typename OpType>
+static std::optional<ArrayAttr>
+getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
+  // The presence check is the same for loads and stores.
+  if (!op.getCacheControl())
+    return {};
+  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t loadCacheControlKey{6442};
+  constexpr int32_t storeCacheControlKey{6443};
+  constexpr int32_t controlKey{isLoad ? loadCacheControlKey
+                                      : storeCacheControlKey};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
+      controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
+      controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
+  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
+
+  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
+  return rewriter.getArrayAttr(combinedAttrs);
+}
+
+/// Emits a call to the device function `funcName` at the rewriter's current
+/// insertion point and returns the new call op.
+///
+/// The callee is looked up, or declared if missing, in the nearest enclosing
+/// symbol table. It is given the SPIR_FUNC calling convention, the flags from
+/// `funcAttributeOptions`, and a unit attribute for every (argument index,
+/// attribute name) pair in `paramAttrs`. The call is created with an unknown
+/// location and receives a copy of the callee's attribute list.
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args,
+    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
+    LLVMFuncAttributeOptions funcAttributeOptions) {
+  Operation *symbolTableOp =
+      rewriter.getBlock()
+          ->getParentOp()
+          ->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(symbolTableOp && "Expecting module");
+
+  // Find or declare the callee.
+  auto declOrFailure = LLVM::lookupOrCreateFn(rewriter, symbolTableOp,
+                                              funcName, argTypes, retType);
+  assert(!failed(declOrFailure));
+  LLVM::LLVMFuncOp decl = declOrFailure.value();
+
+  // Function-level attributes.
+  decl.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+  decl.setConvergent(funcAttributeOptions.isConvergent);
+  decl.setNoUnwind(funcAttributeOptions.isNoUnwind);
+  decl.setWillReturn(funcAttributeOptions.isWillReturn);
+  if (funcAttributeOptions.memEffectsAttr)
+    decl.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
+
+  // Per-argument unit attributes.
+  for (auto [argIdx, attrName] : paramAttrs)
+    decl.setArgAttr(argIdx, attrName, rewriter.getUnitAttr());
+
+  Location unknownLoc = UnknownLoc::get(rewriter.getContext());
+  auto call = rewriter.create<LLVM::CallOp>(unknownLoc, decl, args);
+  call->setAttrs(decl->getAttrs());
+  return call;
+}
+
+/// Lowers xevm.mma to a call to the mangled device builtin
+/// `intel_sub_group_<A>_<B>_matrix_mad_k<K>`, bitcasting the operands to the
+/// packed vector types used in the builtin signature and casting the result
+/// back to the original accumulator type when needed.
+class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The builtin signature always includes the accumulator.
+    if (!op.getC()) {
+      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
+    }
+    // Element widths A and B are repacked to, unless the element type is TF32.
+    constexpr uint32_t bitWidthPackedA{16};
+    constexpr uint32_t bitWidthPackedB{32};
+    auto loc = op.getLoc();
+
+    // Bitcasts `val` to a vector of `packedType` with the same total bit
+    // size; returns `val` unchanged if it already has that type.
+    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
+      VectorType origTy = cast<VectorType>(val.getType());
+      const uint32_t vecBitSize =
+          origTy.getNumElements() *
+          origTy.getElementType().getIntOrFloatBitWidth();
+      VectorType newTy = VectorType::get(
+          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
+      if (origTy != newTy)
+        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+      return val;
+    };
+
+    // A: f32 elements for TF32, otherwise packed into 16-bit integers.
+    Value a = op.getA();
+    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedA);
+    a = castIfNeeded(a, packedAType);
+
+    // B: f32 elements for TF32, otherwise packed into 32-bit integers.
+    Value b = op.getB();
+    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
+                           ? cast<Type>(rewriter.getF32Type())
+                           : rewriter.getIntegerType(bitWidthPackedB);
+    b = castIfNeeded(b, packedBType);
+
+    Value c = op.getC();
+    VectorType cOrigTy = cast<VectorType>(c.getType());
+    assert(cOrigTy == op->getResultTypes()[0] &&
+           "Accumulator and result type mismatch");
+    // OCL builtins encode bfloat16 as int16
+    VectorType cTy =
+        cOrigTy.getElementType().isBF16()
+            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
+            : cOrigTy;
+    if (cOrigTy != cTy)
+      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+
+    // The k suffix of the builtin name is systolicDepth times the number of
+    // A operands that fit in a dword.
+    constexpr int32_t systolicDepth{8};
+    std::string fnName =
+        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
+                      stringifyElemType(op.getTypes().getA()).str(),
+                      stringifyElemType(op.getTypes().getB()).str(),
+                      systolicDepth *
+                          getNumOperandsPerDword(op.getTypes().getA()))
+            .str();
+    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
+    fnName = mangle(fnName, argTypes);
+    SmallVector<Value> args{a, b, c};
+
+    // The callee is declared with no memory effects at all.
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::NoModRef,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
+    funcAttrs.memEffectsAttr = memAttr;
+    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
+                                            args, {}, funcAttrs)
+                       ->getResult(0);
+
+    // Undo the bf16 -> i16 cast so the replacement has the op's result type.
+    if (cOrigTy != cTy)
+      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  /// Number of elements of type `pTy` per 32-bit dword: 1 for TF32, 2 for
+  /// BF16/F16, 4 for U8/S8. Any other element type reaches llvm_unreachable.
+  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
+    switch (pTy) {
+    case xevm::ElemType::TF32:
+      return 1;
+    case xevm::ElemType::BF16:
+    case xevm::ElemType::F16:
+      return 2;
+    case xevm::ElemType::U8:
+    case xevm::ElemType::S8:
+      return 4;
+    default:
+      llvm_unreachable("unsupported xevm::ElemType");
+    }
+  }
+};
+
+/// Lowers xevm.prefetch to a call to the pre-mangled device function
+/// `_Z8prefetchPU3AS1Kcm`, passing the pointer and a constant count of 1, and
+/// attaches the cache-control decoration to the call when the op has one.
+class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // The name is already mangled, so mangle() is not used here.
+    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> args{op.getPtr(), one};
+    SmallVector<Type> argTypes;
+    for (auto arg : args)
+      argTypes.push_back(arg.getType());
+    // Callee is nounwind and only reads argument memory.
+    auto funcAttr = noUnwindAttrs;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/LLVM::ModRefInfo::NoModRef,
+        /*argMem=*/LLVM::ModRefInfo::Ref,
+        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+    funcAttr.memEffectsAttr = memAttr;
+
+    LLVM::CallOp call = createDeviceFunctionCall(
+        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+        argTypes, args, {}, funcAttr);
+    // Prefetch uses the load flavour of the cache-control mapping.
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata<true>(rewriter, op))
+      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    // The op has no result to replace, so it is simply erased.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers xevm.memfence to a call to the mangled device function
+/// `atomic_work_item_fence`, passing three i32 constants: the address-space
+/// flag, the constant 4 (named `acqRel` below), and the memory scope.
+///
+/// Only the SHARED and GLOBAL address spaces and the WORKGROUP and DEVICE
+/// scopes are handled. For any other value the pattern reports a match
+/// failure instead of hitting llvm_unreachable, so unsupported input IR makes
+/// the conversion fail cleanly rather than invoking undefined behaviour in
+/// release builds.
+class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    const std::string fnName{"atomic_work_item_fence"};
+    int memScope = 0;
+    int addrSpace = 0;
+    // Both switches run before any IR is created, so bailing out here leaves
+    // the IR untouched.
+    switch (op.getAddrspace()) {
+    case xevm::AddrSpace::SHARED:
+      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
+      break;
+    case xevm::AddrSpace::GLOBAL:
+      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
+      break;
+    default:
+      // GENERIC is not supported in OpenCL
+      return rewriter.notifyMatchFailure(
+          op, "Fence only supports global and shared address spaces.");
+    }
+    switch (op.getScope()) {
+    case xevm::MemScope::WORKGROUP:
+      memScope = 1;
+      break;
+    case xevm::MemScope::DEVICE:
+      memScope = 2;
+      break;
+    default:
+      // CLUSTER and SYSTEM are not supported in OpenCL
+      return rewriter.notifyMatchFailure(
+          op, "Fence only supports workgroup and device memory scopes.");
+    }
+    Value acqRel = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(4));
+    Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope));
+    Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace));
+    SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
+    // Three i32 parameters (count, value constructor).
+    SmallVector<Type> argTypes{3, rewriter.getI32Type()};
+    createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
+                             LLVM::LLVMVoidType::get(rewriter.getContext()),
+                             argTypes, args, {}, noUnwindAttrs);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+/// Lowers the 2D block ops (BlockLoad2dOp, BlockStore2dOp, BlockPrefetch2dOp)
+/// to calls to the `intel_sub_group_2d_block_{read,write,prefetch}` device
+/// functions.
+///
+/// Loads and stores go through a stack slot: the builtin receives a pointer
+/// to an alloca holding the tile, which is written before the call (store) or
+/// read back after the call (load). Prefetch has no data operand.
+template <typename OpType>
+class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Any OpType that is neither load nor prefetch is treated as a store.
+    constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
+    constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
+
+    auto loc = op.getLoc();
+    // vecType stays null for prefetch; it is only used on the load/store path.
+    VectorType vecType;
+    bool packReg = false;
+    bool transpose = false;
+    if constexpr (isLoad) {
+      vecType = op.getRes().getType();
+      packReg = op.getPackRegister();
+      transpose = op.getTranspose();
+    } else if constexpr (!isPrefetch) {
+      vecType = op.getStoredVal().getType();
+    }
+
+    // Pack the (x, y) coordinates into a vector<2xi32>.
+    auto i32Type = rewriter.getI32Type();
+    Value byteCoord =
+        rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getI32IntegerAttr(0));
+    Value one = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getI32IntegerAttr(1));
+    byteCoord = rewriter.create<LLVM::InsertElementOp>(
+        loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
+    byteCoord = rewriter.create<LLVM::InsertElementOp>(
+        loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
+    // Arguments common to all three builtins; load/store append a sixth.
+    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
+                            op.getBasePitch(), byteCoord};
+    SmallVector<Type> retTypes;
+    Value spvLoadDstPtr;
+    std::string funcName{"intel_sub_group_2d_block_"};
+    // Mangled element type of the data pointer; stays empty for prefetch.
+    std::string bitWidthId;
+    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
+    SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
+    if constexpr (isPrefetch) { // Prefetch
+      funcName += "prefetch";
+      paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
+      auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+          /*other=*/LLVM::ModRefInfo::NoModRef,
+          /*argMem=*/LLVM::ModRefInfo::Ref,
+          /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
+      // Prefetch drops willreturn and only reads argument memory.
+      funcAttr = noUnwindAttrs;
+      funcAttr.memEffectsAttr = memAttr;
+    } else {
+      auto vecElemType = vecType.getElementType();
+      auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
+      // Stack slot sized for the whole tile; passed as argument index 5.
+      Value numElems = rewriter.create<LLVM::ConstantOp>(
+          loc, i32Type, vecType.getNumElements());
+      auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
+          loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
+          numElems);
+      args.push_back(dstOrSrcPtr);
+      if constexpr (isLoad) { // Load
+        funcName += "read";
+        bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
+        // pack_register takes precedence over transpose when both are set.
+        if (packReg)
+          funcName += "_transform";
+        else if (transpose)
+          funcName += "_transpose";
+        spvLoadDstPtr = dstOrSrcPtr;
+        retTypes.push_back(vecType);
+        // Source pointer (arg 0) is read, stack slot (arg 5) is written.
+        paramAttrs = {
+            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
+        };
+      } else { // Store
+        funcName += "write";
+        // NOTE(review): unlike the load path this is derived from the bit
+        // width only; any width other than 32 or 16 falls back to "h".
+        bitWidthId = (vecElemBitWidth == 32)
+                         ? "j"
+                         : ((vecElemBitWidth == 16) ? "t" : "h");
+        // Spill the value to the stack slot so the builtin can read it.
+        rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
+        // Destination pointer (arg 0) is written, stack slot (arg 5) is read.
+        paramAttrs = {
+            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
+            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
+        };
+      }
+    }
+
+    // Append the tile description: <elem bits>b_<height>r<width>x<blocks>c.
+    funcName =
+        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
+                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
+            .str();
+    // Hand-built mangling: length-prefixed name followed by the fixed
+    // parameter encoding, plus "P<elem>" for the data pointer on load/store.
+    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
+                             funcName, isPrefetch ? "" : "P", bitWidthId)
+                   .str();
+    SmallVector<Type> argTypes;
+    for (auto arg : args) {
+      argTypes.push_back(arg.getType());
+    }
+    LLVM::CallOp call = createDeviceFunctionCall(
+        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
+        argTypes, args, paramAttrs, funcAttr);
+    // Load and prefetch use the load cache-control mapping, store the other.
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
+      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    }
+    // A load is replaced by reading the tile back from the stack slot; store
+    // and prefetch have no result and are erased.
+    if constexpr (isLoad)
+      rewriter.replaceOp(
+          op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+/// Pass that rewrites every XeVM dialect op using the XeVM-to-LLVM patterns.
+/// The LLVM dialect is legal and the XeVM dialect is illegal; the pass fails
+/// if the partial conversion fails.
+struct ConvertXeVMToLLVMPass
+    : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
+  using Base::Base;
+
+  /// Registers the dialects this pass needs loaded.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect, XeVMDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+
+    RewritePatternSet conversionPatterns(&context);
+    populateXeVMToLLVMConversionPatterns(conversionPatterns);
+
+    ConversionTarget conversionTarget(context);
+    conversionTarget.addIllegalDialect<XeVMDialect>();
+    conversionTarget.addLegalDialect<LLVM::LLVMDialect>();
+
+    const LogicalResult status = applyPartialConversion(
+        getOperation(), conversionTarget, std::move(conversionPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// ConvertToLLVMPatternInterface implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Implement the interface to convert XeVM to LLVM.
+struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+  /// Loads the LLVM dialect, which the conversion patterns create ops from.
+  void loadDependentDialects(MLIRContext *context) const final {
+    context->loadDialect<LLVM::LLVMDialect>();
+  }
+
+  /// Hook that adds the XeVM-to-LLVM conversion patterns to `patterns`.
+  /// `target` and `typeConverter` are not used or modified here.
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateXeVMToLLVMConversionPatterns(patterns);
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds every XeVM-to-LLVM conversion pattern defined in this file to
+/// `patterns`, using the pattern set's own context.
+void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  // 2D block load / store / prefetch share one templated pattern.
+  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
+               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
+               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>>(ctx);
+  patterns.add<MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
+      ctx);
+}
+
+/// Registers an extension on `registry` that attaches
+/// XeVMToLLVMDialectInterface to the XeVM dialect.
+void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
+  // The unary plus converts the capture-less lambda to a function pointer.
+  auto attachInterface = +[](MLIRContext *, XeVMDialect *xevmDialect) {
+    xevmDialect->addInterfaces<XeVMToLLVMDialectInterface>();
+  };
+  registry.addExtension(attachInterface);
+}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
new file mode 100644
index 0000000000000..2d9a9d5683756
--- /dev/null
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
+
+// Same as above, but using the `ConvertToLLVMPatternInterface` entry point
+// and the generic `convert-to-llvm` pass.
+// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -> vector<8xi16> {
+llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
+  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  llvm.return %loaded_a : vector<8xi16>
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return}
+// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) {
+llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr
+  // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) -> ()
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  llvm.return
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
+llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+  // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
+  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+  llvm.return
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
+// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> {
+llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
+  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  llvm.return %c_result : vector<8xf32>
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind}
+llvm.func @memfence() {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) {function_type = !llvm.func<void (i32, i32, i32)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> ()
+  xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
+  llvm.return
+}
+
+// -----
+// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) {
+llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
+  xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  llvm.return
+}
+
diff --git a/mlir/test/lib/Dialect/GPU/CMakeLists.txt b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
index 4ca5974ed5a49..418c884dc03b3 100644
--- a/mlir/test/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/GPU/CMakeLists.txt
@@ -29,6 +29,7 @@ set(LIBS
   MLIRTranslateLib
   MLIRVectorDialect
   MLIRVectorToLLVMPass
+  MLIRXeVMDialect
   )
 
 add_mlir_library(MLIRGPUTestPasses

>From 357377bd075d42dfac6b7d4618349e496d355546 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 9 Jul 2025 03:54:57 +0000
Subject: [PATCH 2/5] Address reviewer comments.

---
 mlir/include/mlir/Conversion/Passes.td        |   2 +-
 .../mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h   |   2 -
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 125 ++++++------------
 .../Conversion/XeVMToLLVM/xevm-to-llvm.mlir   |  86 +++++++++---
 4 files changed, 106 insertions(+), 109 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 929c3cf5a25ed..78e68632409d8 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1501,7 +1501,7 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
 
 def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
   let summary = "Convert XeVM to LLVM dialect";
-  let dependentDialects = ["xevm::XeVMDialect", ];
+  let dependentDialects = ["LLVM::LLVMDialect", ];
 }
 
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
index b361af573afff..7ffdbd4307f9e 100644
--- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -15,9 +15,7 @@ class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
 class Pass;
-} // namespace mlir
 
-namespace mlir {
 #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
 
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 89407825fd656..4605ef78ee50d 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -23,8 +23,6 @@
 
 #include "llvm/ADT/TypeSwitch.h"
 
-#define DEBUG_TYPE "xevm-to-llvm"
-
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -70,6 +68,9 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
         default:
           llvm_unreachable("unhandled integer type");
         }
+      })
+      .Default([](Type) -> std::string {
+        llvm_unreachable("unhandled type for mangling");
       });
 }
 
@@ -165,38 +166,18 @@ int32_t getL3CacheControl(OpType op) {
   if constexpr (isLoad) {
     switch (*op.getCacheControl()) {
     case LoadCacheControl::L1UC_L2UC_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1UC_L2UC_L3C:
-      control = 2;
-      break;
     case LoadCacheControl::L1UC_L2C_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1UC_L2C_L3C:
-      control = 2;
-      break;
     case LoadCacheControl::L1C_L2UC_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1C_L2UC_L3C:
-      control = 2;
-      break;
     case LoadCacheControl::L1C_L2C_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1C_L2C_L3C:
-      control = 2;
-      break;
     case LoadCacheControl::L1S_L2UC_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1S_L2UC_L3C:
-      control = 2;
-      break;
     case LoadCacheControl::L1S_L2C_L3UC:
       control = 1;
       break;
+    case LoadCacheControl::L1UC_L2UC_L3C:
+    case LoadCacheControl::L1UC_L2C_L3C:
+    case LoadCacheControl::L1C_L2UC_L3C:
+    case LoadCacheControl::L1C_L2C_L3C:
+    case LoadCacheControl::L1S_L2UC_L3C:
     case LoadCacheControl::L1S_L2C_L3C:
       control = 2;
       break;
@@ -209,47 +190,21 @@ int32_t getL3CacheControl(OpType op) {
   } else {
     switch (*op.getCacheControl()) {
     case StoreCacheControl::L1UC_L2UC_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1UC_L2UC_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1UC_L2WB_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1UC_L2WB_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1WT_L2UC_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1WT_L2UC_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1WT_L2WB_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1WT_L2WB_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1S_L2UC_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1S_L2UC_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1S_L2WB_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1S_L2WB_L3WB:
-      control = 2;
-      break;
     case StoreCacheControl::L1WB_L2UC_L3UC:
-      control = 1;
-      break;
     case StoreCacheControl::L1WB_L2WB_L3UC:
       control = 1;
       break;
+    case StoreCacheControl::L1UC_L2UC_L3WB:
+    case StoreCacheControl::L1UC_L2WB_L3WB:
+    case StoreCacheControl::L1WT_L2UC_L3WB:
+    case StoreCacheControl::L1WT_L2WB_L3WB:
+    case StoreCacheControl::L1S_L2UC_L3WB:
+    case StoreCacheControl::L1S_L2WB_L3WB:
     case StoreCacheControl::L1WB_L2UC_L3WB:
       control = 2;
       break;
@@ -263,13 +218,8 @@ int32_t getL3CacheControl(OpType op) {
 template <bool isLoad, typename OpType>
 static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
-  if constexpr (isLoad) {
-    if (!op.getCacheControl())
-      return {};
-  } else {
-    if (!op.getCacheControl())
-      return {};
-  }
+  if (!op.getCacheControl())
+    return {};
   constexpr int32_t decorationCacheControlArity{4};
   constexpr int32_t loadCacheControlKey{6442};
   constexpr int32_t storeCacheControlKey{6443};
@@ -289,13 +239,12 @@ static LLVM::CallOp createDeviceFunctionCall(
     ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
     ArrayRef<Type> argTypes, ArrayRef<Value> args,
     mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
-    LLVMFuncAttributeOptions funcAttributeOptions) {
+    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
   auto moduleOp = rewriter.getBlock()
                       ->getParentOp()
                       ->getParentWithTrait<OpTrait::SymbolTable>();
   assert(moduleOp && "Expecting module");
-  MLIRContext *ctx = rewriter.getContext();
-  Location loc = UnknownLoc::get(ctx);
+  Location loc = op->getLoc();
 
   auto funcOpRes =
       LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
@@ -384,9 +333,10 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
         /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
     auto funcAttrs = convergentNoUnwindWillReturnAttrs;
     funcAttrs.memEffectsAttr = memAttr;
-    Value result = createDeviceFunctionCall(rewriter, fnName, cTy, argTypes,
-                                            args, {}, funcAttrs)
-                       ->getResult(0);
+    Value result =
+        createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
+                                 funcAttrs, op.getOperation())
+            ->getResult(0);
 
     if (cOrigTy != cTy)
       result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
@@ -419,8 +369,8 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
-    Value one = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1));
+    Value one =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
     SmallVector<Value> args{op.getPtr(), one};
     SmallVector<Type> argTypes;
     for (auto arg : args)
@@ -434,7 +384,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
 
     LLVM::CallOp call = createDeviceFunctionCall(
         rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
-        argTypes, args, {}, funcAttr);
+        argTypes, args, {}, funcAttr, op.getOperation());
     if (std::optional<ArrayAttr> optCacheControls =
             getCacheControlMetadata<true>(rewriter, op))
       call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
@@ -473,17 +423,18 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
       // CLUSTER and SYSTEM are not supported in OpenCL
       llvm_unreachable("unsupported xevm::MemoryScope");
     }
-    Value acqRel = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(4));
-    Value memScopeConst = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(memScope));
-    Value addrSpaceConst = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(addrSpace));
+    Type i32Type = rewriter.getI32Type();
+    Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
+    Value memScopeConst =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
+    Value addrSpaceConst =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
     SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
-    SmallVector<Type> argTypes{3, rewriter.getI32Type()};
+    SmallVector<Type> argTypes{3, i32Type};
     createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
                              LLVM::LLVMVoidType::get(rewriter.getContext()),
-                             argTypes, args, {}, noUnwindAttrs);
+                             argTypes, args, {}, noUnwindAttrs,
+                             op.getOperation());
     rewriter.eraseOp(op);
     return success();
   }
@@ -512,10 +463,8 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
     auto i32Type = rewriter.getI32Type();
     Value byteCoord =
         rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        loc, i32Type, rewriter.getI32IntegerAttr(0));
-    Value one = rewriter.create<LLVM::ConstantOp>(
-        loc, i32Type, rewriter.getI32IntegerAttr(1));
+    Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
+    Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
     byteCoord = rewriter.create<LLVM::InsertElementOp>(
         loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
     byteCoord = rewriter.create<LLVM::InsertElementOp>(
@@ -589,7 +538,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
     }
     LLVM::CallOp call = createDeviceFunctionCall(
         rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
-        argTypes, args, paramAttrs, funcAttr);
+        argTypes, args, paramAttrs, funcAttr, op.getOperation());
     if (std::optional<ArrayAttr> optCacheControls =
             getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
       call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 2d9a9d5683756..aeb9e56035653 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -4,8 +4,11 @@
 // and the generic `convert-to-llvm` pass.
 // RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
 
-// CHECK: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
-// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) -> vector<8xi16> {
+// CHECK:      llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+// CHECK-SAME:   !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+// CHECK-SAME:   !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK:      llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME:   %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
 llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
   // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
   // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -14,15 +17,27 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
   // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
-  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
+  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) 
+  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
+  // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
+  // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
+  // CHECK-SAME:   will_return}
+  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+  // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
   // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
-  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
+      pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
   llvm.return %loaded_a : vector<8xi16>
 }
 
 // -----
-// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return}
-// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) {
+// CHECK: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
+// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
+// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return}
+// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) {
 llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
   // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
   // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -32,31 +47,60 @@ llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i3
   // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
   // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
   // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr
-  // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.readonly}) -> ()
-  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+  // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
+  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
+  // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
+  // CHECK-SAME:   "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
+  // CHECK-SAME:   will_return}
+  // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
+  // CHECK-SAME:    !llvm.ptr {llvm.nonnull, llvm.readonly}) -> ()
+  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
+    <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
+    : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
   llvm.return
 }
 
 // -----
-// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
-// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
+// CHECK: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
+// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes
+// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
 llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
   // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
   // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]]) {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
-  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+  // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
+  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
+  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:   memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
+  // CHECK-SAME:   sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
+  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
+    <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
+      cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
+    : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
   llvm.return
 }
 
 // -----
-// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
+// CHECK: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
+// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes
+// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none,
+// CHECK-SAME:   inaccessibleMem = none>, no_unwind, will_return}
 // CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> {
 llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
-  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type = !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
-  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
+  // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type =
+  // CHECK-SAME:   !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:   memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind,
+  // CHECK-SAME:   sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return}
+  // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted
+    { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
+    : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
   llvm.return %c_result : vector<8xf32>
 }
 
@@ -66,17 +110,23 @@ llvm.func @memfence() {
   // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32
-  // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]]) {function_type = !llvm.func<void (i32, i32, i32)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> ()
+  // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]])
+  // CHECK-SAME: {function_type = !llvm.func<void (i32, i32, i32)>, linkage = #llvm.linkage<external>, no_unwind,
+  // CHECK-SAME:   sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> ()
   xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
   llvm.return
 }
 
 // -----
-// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
+// CHECK: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes
+// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
 // CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) {
 llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
   // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64
-  // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]]) {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
+  // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]])
+  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:   memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>,
+  // CHECK-SAME:   no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
   xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
   llvm.return
 }

>From 68093205c83901823dcd3e2fc34dc1daa9fa93c2 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 9 Jul 2025 04:20:53 +0000
Subject: [PATCH 3/5] Address reviewer comment.

---
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 4605ef78ee50d..d93ba7915a38c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -240,9 +240,7 @@ static LLVM::CallOp createDeviceFunctionCall(
     ArrayRef<Type> argTypes, ArrayRef<Value> args,
     mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
     LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
-  auto moduleOp = rewriter.getBlock()
-                      ->getParentOp()
-                      ->getParentWithTrait<OpTrait::SymbolTable>();
+  auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
   assert(moduleOp && "Expecting module");
   Location loc = op->getLoc();
 

>From 49efd761407cbdfde69693b9eb43f16ab85f5b35 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 9 Jul 2025 17:58:56 +0000
Subject: [PATCH 4/5] Remove trailing comma.

---
 mlir/include/mlir/Conversion/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 78e68632409d8..50c67da91a4af 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1501,7 +1501,7 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
 
 def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
   let summary = "Convert XeVM to LLVM dialect";
-  let dependentDialects = ["LLVM::LLVMDialect", ];
+  let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
 #endif // MLIR_CONVERSION_PASSES

>From 9fc0400c84702cc0e9dc954216a311f140305521 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 10 Jul 2025 14:28:53 +0000
Subject: [PATCH 5/5] Use result type instead of C type for MMA.

---
 mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index d93ba7915a38c..a1dce9270ec68 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -303,13 +303,14 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
 
     Value c = op.getC();
     VectorType cOrigTy = cast<VectorType>(c.getType());
-    assert(cOrigTy == op->getResultTypes()[0] &&
-           "Accumulator and result type mismatch");
+    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
+    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
     // OCL builtins encode bfloat16 as int16
     VectorType cTy =
         cOrigTy.getElementType().isBF16()
             ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
             : cOrigTy;
+    VectorType resTy = cTy;
     if (cOrigTy != cTy)
       c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
 
@@ -332,12 +333,12 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
     auto funcAttrs = convergentNoUnwindWillReturnAttrs;
     funcAttrs.memEffectsAttr = memAttr;
     Value result =
-        createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, args, {},
+        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                  funcAttrs, op.getOperation())
             ->getResult(0);
 
-    if (cOrigTy != cTy)
-      result = rewriter.create<LLVM::BitcastOp>(loc, cOrigTy, result);
+    if (resOrigTy != resTy)
+      result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
 
     rewriter.replaceOp(op, result);
     return success();



More information about the Mlir-commits mailing list