[Mlir-commits] [mlir] [MLIR][SPIRV][XeVM] Add MathToXeVM (`math-to-xevm`) pass (PR #159878)

Ian Li llvmlistbot at llvm.org
Thu Oct 2 08:55:16 PDT 2025


https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/159878

>From 8af837573e9325d5f9f1e38bd0e4c510a58d37a1 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 19 Sep 2025 16:44:38 -0700
Subject: [PATCH 01/13] Initial mockup of the MathToXeVM pass

---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  26 +++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  17 ++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt |  24 +++
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 159 ++++++++++++++++++
 .../Conversion/MathToXeVM/math-to-xevm.mlir   |  22 +++
 .../MathToXeVM/native-spirv-builtins.mlir     |  33 ++++
 8 files changed, 283 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
 create mode 100644 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
 create mode 100644 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
 create mode 100644 mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..7982aa3769e84
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,26 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate `patterns` with rewrites of `afn` Math ops to OpenCL native calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..ead4d5c50046d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -84,6 +84,7 @@
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 
 namespace mlir {
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..20e0b95cc5c78 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,23 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+  let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+  let description = [{
+    This pass rewrites Math ops flagged `afn` into OpenCL native builtin calls.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "LLVM::LLVMDialect", "xevm::XeVMDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToEmitC
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..3f389359a6a2c
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+// TODO check if everything here is needed
+add_mlir_conversion_library(MLIRMathToXeVM
+  MathToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC # NOTE(review): MathToXeVM.cpp builds llvm.func/llvm.call; MLIRLLVMDialect is presumably reached via MLIRLLVMCommonConversion -- consider listing it directly
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..e18350219ffe8
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,159 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFuncCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Lower `afn` (incl. `fast`) math ops to calls of OpenCL native builtins.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
+    : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: OCL doesn't provide native int intrinsics, but check what happens
+    // when IGC receives a native_exp on ints anyway
+    // TODO: what about vectorization?
+    if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+      return failure();
+
+    arith::FastMathFlags fastFlags = op.getFastmath();
+    if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+      return failure();
+
+    // FIXME: Implement handling for vector sizes/dimensions that are not
+    // supported by SPIRV
+    SmallVector<Type, 1> operandTypes;
+    for (auto operand : adaptor.getOperands()) {
+      if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+        return failure();
+      operandTypes.push_back(operand.getType());
+    }
+    LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+    return success();
+  }
+
+  inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+    if (type.isFloat()) { // NOTE(review): isFloat() admits any FloatType (bf16,
+      return true;        // f8, ...), but the mangler below only spells f16/f32/f64.
+    } else if (auto vecType = dyn_cast<VectorType>(type)) {
+      if (!vecType.getElementType().isFloat())
+        return false;
+      // SPIRV distinguishes between vectors and matrices: OpenCL native math
+      // intrinsics are not compatible with matrices.
+      ArrayRef<int64_t> shape = vecType.getShape();
+      if (shape.size() != 1)
+        return false;
+      // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+        return true;
+    }
+    return false;
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+    // This function assumes op types have already been validated using
+    // isSPIRVCompatibleFloatOrVec.
+    using LLVM::LLVMFuncOp;
+
+    std::string mangledNativeFunc =
+        "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+    auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+      if (type.isF32())
+        mangledNativeFunc += "f";
+      else if (type.isF16())
+        mangledNativeFunc += "Dh";
+      else if (type.isF64())
+        mangledNativeFunc += "d";
+    };
+
+    for (auto type : operandTypes) {
+      if (auto vecType = dyn_cast<VectorType>(type)) {
+        mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        appendFloatToMangledFunc(vecType.getElementType());
+      } else
+        appendFloatToMangledFunc(type);
+    }
+
+    auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+    auto funcOp =
+         SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+    if (funcOp)
+      return funcOp;
+
+    auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+    assert(parentFunc && "expected there to be a parent function");
+    OpBuilder b(parentFunc);
+
+    // Create a valid global location removing any metadata attached to the
+    // location as debug info metadata inside of a function cannot be used
+    // outside of that function.
+    auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+    auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+    return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+  }
+
+  const StringRef nativeFunc;
+};
+
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
+    patterns.getContext(), "__spirv_ocl_native_exp");
+}
+
+namespace {
+struct ConvertMathToXeVMPass // Pass driver for -convert-math-to-xevm.
+    : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+  ConvertMathToXeVMPass() = default;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+  auto m = getOperation();
+  //MLIRContext *ctx = m.getContext();
+
+  RewritePatternSet patterns(&getContext());
+  populateMathToXeVMConversionPatterns(patterns);
+  ConversionTarget target(getContext()); // Nothing illegal: unmatched ops stay.
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..436d0e0941b9e
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+
+module @test_module {
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  // CHECK-LABEL: func @math_ops
+  func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+    %result16 = math.exp %arg_f16 fastmath<fast> : f16
+    
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+    %result64 = math.exp %arg_f64 fastmath<afn> : f64
+
+    // CHECK: math.exp
+    %result_no_fast = math.exp %arg_f64 : f64
+
+    // TODO check fastmath<none>
+
+    func.return %result16, %result64 : f16, f64
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000000000..f762f4b60f818
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN:             -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+// REQUIRES: asserts
+// MathToXeVM pass generates OpenCL intrinsic function calls when converting
+// Math ops with `fastmath` attr to native function calls. It is assumed that
+// the SPIRV backend would correctly convert these intrinsic calls to OpenCL
+// ExtInst instructions in SPIRV (see llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+  gpu.module @kernel [#xevm.target] {
+    llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+      // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+      // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+      %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+      // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+      // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+      %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+      // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+      %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+      // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+      %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+      llvm.return
+    }
+    llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+    llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  }
+}

>From 59160d682d89a019178cf1e74433e29a5e3a75ad Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:10:29 -0700
Subject: [PATCH 02/13] clang-format

---
 mlir/include/mlir/Conversion/Passes.h         |  2 +-
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 28 +++++++++++--------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ead4d5c50046d..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
@@ -84,7 +85,6 @@
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 
 namespace mlir {
 
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e18350219ffe8..e1f1205d6efaa 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -38,8 +38,9 @@ using namespace mlir;
 template <typename Op>
 struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
-  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
-    : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+  ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+                           PatternBenefit benefit = 1)
+      : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
 
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
@@ -51,7 +52,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
       return failure();
 
     // FIXME: Implement handling for vector sizes/dimensions that are not
@@ -63,7 +64,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       operandTypes.push_back(operand.getType());
     }
     LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+                                              adaptor.getOperands());
     return success();
   }
 
@@ -79,13 +81,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       if (shape.size() != 1)
         return false;
       // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
-      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+      if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+          shape[0] == 16)
         return true;
     }
     return false;
   }
 
-  LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+  LLVM::LLVMFuncOp
+  appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
     // This function assumes op types have already been validated using
     // isSPIRVCompatibleFloatOrVec.
     using LLVM::LLVMFuncOp;
@@ -112,7 +116,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
     auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
     auto funcOp =
-         SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+        SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
     if (funcOp)
       return funcOp;
 
@@ -124,17 +128,17 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     // location as debug info metadata inside of a function cannot be used
     // outside of that function.
     auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
-    auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+    auto globalloc =
+        op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
     return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
   }
 
   const StringRef nativeFunc;
 };
 
-
 void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
-    patterns.getContext(), "__spirv_ocl_native_exp");
+  patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_exp");
 }
 
 namespace {
@@ -147,7 +151,7 @@ struct ConvertMathToXeVMPass
 
 void ConvertMathToXeVMPass::runOnOperation() {
   auto m = getOperation();
-  //MLIRContext *ctx = m.getContext();
+  // MLIRContext *ctx = m.getContext();
 
   RewritePatternSet patterns(&getContext());
   populateMathToXeVMConversionPatterns(patterns);

>From 89f94eadf57b3aeb6fb416ed48f1fc7ba1cf1476 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:18:45 -0700
Subject: [PATCH 03/13] fix cmake

---
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 3f389359a6a2c..711c6876bb168 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,4 @@
-// TODO check if everything here is needed
+# TODO check if everything here is needed
 add_mlir_conversion_library(MLIRMathToXeVM
   MathToXeVM.cpp
 

>From bf4ed92c054490aa46d1273c0020c8345ce5c320 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 23 Sep 2025 16:03:43 -0700
Subject: [PATCH 04/13] update tests

---
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 73 +++++++++++++++++--
 .../MathToXeVM/native-spirv-builtins.mlir     | 40 ++++++++++
 2 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 436d0e0941b9e..00719e9b881d2 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -2,21 +2,82 @@
 
 module @test_module {
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
   // CHECK-LABEL: func @math_ops
-  func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+  func.func @math_ops() {
+
+    %c1_f16 = arith.constant 1. : f16
+    %c1_f32 = arith.constant 1. : f32
+    %c1_f64 = arith.constant 1. : f64
+
+    // CHECK: math.exp
+    %res_normal_f16 = math.exp %c1_f16 : f16
+    // CHECK: math.exp
+    %res_normal_f32 = math.exp %c1_f32 : f32
+    // CHECK: math.exp
+    %res_normal_f64 = math.exp %c1_f64 : f64
 
     // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
-    %result16 = math.exp %arg_f16 fastmath<fast> : f16
+    %res_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
+    %res_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+    %res_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
     
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+    %res_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
+    %res_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
     // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
-    %result64 = math.exp %arg_f64 fastmath<afn> : f64
+    %res_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
 
     // CHECK: math.exp
-    %result_no_fast = math.exp %arg_f64 : f64
+    %res_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+    // CHECK: math.exp
+    %res_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+    // CHECK: math.exp
+    %res_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+    %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+    %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+    %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+    %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+    %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64>
+    %res_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64>
+    %res_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64>
+    %res_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+    %res_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64>
+    %res_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+    %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+    %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
 
-    // TODO check fastmath<none>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32>
+    %res_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16>
+    %res_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+    %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+    %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+    // CHECK: math.exp
+    %res_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+    // CHECK: math.exp
+    %res_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
 
-    func.return %result16, %result64 : f16, f64
+    return
   }
 }
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index f762f4b60f818..92744c9e165da 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -16,18 +16,58 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
       // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
       // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
       %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+      // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
+      // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
+      %c0_f32 = llvm.mlir.constant(0. : f32) : f32
       // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
       // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
       %c0_f64 = llvm.mlir.constant(0. : f64) : f64
 
+      // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
+      // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
+      %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
+      // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
+      // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
+      %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
+      // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
+      // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
+      %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
+      // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
+      // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
+      %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
+      // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
+      // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
+      %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>     
+
       // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
       %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+      // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+      %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
       // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
       %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
 
+      // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+      %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_d(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
+      // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+      %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
+      // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+      %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_d(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
+      // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+      %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_d(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
+      // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+      %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_Dh(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
+
       llvm.return
     }
     llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+    llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
     llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+    llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv3_f(vector<3xf32>) -> vector<3xf32>
+    llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+    llvm.func @_Z22__spirv_ocl_native_expDv16_Dh(vector<16xf16>) -> vector<16xf16>
+
+
   }
 }

>From d7ae5f2a35f3d089fdc997b7374c473c227b9182 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 24 Sep 2025 16:40:30 -0700
Subject: [PATCH 05/13] Add support for all other native ops

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp |  29 +++-
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 136 +++++++++++++-----
 .../MathToXeVM/native-spirv-builtins.mlir     |  13 ++
 3 files changed, 139 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e1f1205d6efaa..055cfdf064e4e 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -47,7 +48,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
                   ConversionPatternRewriter &rewriter) const override {
     // TODO: OCL doesn't provide native int intrinsics, but check what happens
     // when IGC receives a native_exp on ints anyway
-    // TODO: what about vectorization?
     if (!isSPIRVCompatibleFloatOrVec(op.getType()))
       return failure();
 
@@ -56,7 +56,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     // FIXME: Implement handling for vector sizes/dimensions that are not
-    // supported by SPIRV
+    // supported by SPIRV.
     SmallVector<Type, 1> operandTypes;
     for (auto operand : adaptor.getOperands()) {
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
@@ -64,8 +64,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       operandTypes.push_back(operand.getType());
     }
     LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
                                               adaptor.getOperands());
+    arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+    mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+    callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
     return success();
   }
 
@@ -139,6 +142,26 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
   patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_exp");
+  patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_cos");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_sin");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+                                                      "__spirv_ocl_native_tan");
 }
 
 namespace {
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 00719e9b881d2..8e1d20dc94d78 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -4,12 +4,26 @@ module @test_module {
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-
+  //
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+  //
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+  // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+  // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+  // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+  // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+
   // CHECK-LABEL: func @math_ops
   func.func @math_ops() {
 
@@ -18,32 +32,36 @@ module @test_module {
     %c1_f64 = arith.constant 1. : f64
 
     // CHECK: math.exp
-    %res_normal_f16 = math.exp %c1_f16 : f16
+    %exp_normal_f16 = math.exp %c1_f16 : f16
     // CHECK: math.exp
-    %res_normal_f32 = math.exp %c1_f32 : f32
+    %exp_normal_f32 = math.exp %c1_f32 : f32
     // CHECK: math.exp
-    %res_normal_f64 = math.exp %c1_f64 : f64
-
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
-    %res_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
-    %res_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
-    %res_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+    %exp_normal_f64 = math.exp %c1_f64 : f64
+
+    // Check float operations are converted properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+    %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+    %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
     
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
-    %res_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
-    %res_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
-    %res_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
 
     // CHECK: math.exp
-    %res_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+    %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
     // CHECK: math.exp
-    %res_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+    %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
     // CHECK: math.exp
-    %res_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+    %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+    // Check vector operations:
 
     %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
     %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
@@ -51,32 +69,78 @@ module @test_module {
     %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
     %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
 
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64>
-    %res_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64>
-    %res_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64>
-    %res_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
-    %res_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64>
-    %res_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+    %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+    %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+    %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+    %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+    %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
 
     %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
     %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
 
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32>
-    %res_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
-    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16>
-    %res_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
+    %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
+    %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+    // Check unsupported vector sizes are not converted:
 
     %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
     %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
 
     // CHECK: math.exp
-    %res_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+    %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
     // CHECK: math.exp
-    %res_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+    %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+    // Check fastmath flags propagate properly:
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+    %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
+    %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
+    %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
+
+    // Check all other math operations:
+
+    // native_divide(gentype x, gentype y)
+    // TODO: convert arith.divf to arith/native_divide if option is enabled
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
+    %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+    %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
+
+    // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+    %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
+
+    // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+    %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
 
     return
   }
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index 92744c9e165da..b83288c7ec99e 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -57,6 +57,19 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
       // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
       %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
 
+
+      // SPIRV backend does not currently handle fastmath flags: The SPIRV
+      // backend would need to generate OpDecorate calls to decorate math ops
+      // with FPFastMathMode/FPFastMathModeINTEL decorations.
+      //
+      // FIXME: When support for fastmath flags in the SPIRV backend is added, 
+      // add tests here to ensure fastmath flags are converted to the correct
+      // OpDecorate calls.
+      // 
+      // See:
+      // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
+      // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
+
       llvm.return
     }
     llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16

>From 9d3a540934345628510df63b1ac69312d76cea8a Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 12:16:24 -0700
Subject: [PATCH 06/13] finish testing for native-spirv-builtins

---
 .../MathToXeVM/native-spirv-builtins.mlir     | 51 +++++++++++++++----
 1 file changed, 40 insertions(+), 11 deletions(-)

diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index b83288c7ec99e..6bc90e34060b4 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -39,25 +39,24 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
       // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
       %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>     
 
-      // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
       %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
-      // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
       %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
-      // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
       %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
 
-      // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+      // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
       %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
-      // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+      // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
       %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
-      // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+      // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
       %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
-      // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
       %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
-      // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
       %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
 
-
       // SPIRV backend does not currently handle fastmath flags: The SPIRV
       // backend would need to generate OpDecorate calls to decorate math ops
       // with FPFastMathMode/FPFastMathModeINTEL decorations.
@@ -70,8 +69,30 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
       // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
       // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
 
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
+      %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
+      %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
+      %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
+      %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
+      %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+      // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
+      %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
+      %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+      // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
+      %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
+      %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+      // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
+      %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+
       llvm.return
     }
+    
     llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
     llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
     llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
@@ -80,7 +101,15 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
     llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
     llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
     llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
-
-
+    llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+    llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+    llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
+    llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+    llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+    llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+    llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+    llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
   }
 }

>From 0232c265a210231da53081548392247ab27017f3 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 14:59:40 -0700
Subject: [PATCH 07/13] Accommodate arith.divf

---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  2 +-
 mlir/include/mlir/Conversion/Passes.td        | 13 +++++++++++--
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 19 ++++++++++---------
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 13 ++++++++++++-
 .../MathToXeVM/native-spirv-builtins.mlir     |  5 ++++-
 5 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 7982aa3769e84..6bb69361dcb6d 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,7 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 20e0b95cc5c78..976e1b6b183e1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,10 +801,19 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+  let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
   let description = [{
-    This pass converts supported Math ops to XeVM.
+    This pass converts supported math ops marked with the `afn` fastmath flag
+    to function calls for OpenCL `native_` math intrinsics: These intrinsics
+    are typically mapped directly to native device instructions, often resulting
+    in better performance. However, the precision/error of these intrinsics
+    are implementation-defined, and thus math ops are only converted when they
+    have the `afn` fastmath flag enabled.
   }];
+  let options = [
+    Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
+    "Convert supported Arith ops (e.g. arith.divf) as well.">
+  ];
   let dependentDialects = [
     "arith::ArithDialect",
     "func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 055cfdf064e4e..b75f8d3640a41 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -46,8 +46,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: OCL doesn't provide native int intrinsics, but check what happens
-    // when IGC receives a native_exp on ints anyway
     if (!isSPIRVCompatibleFloatOrVec(op.getType()))
       return failure();
 
@@ -55,10 +53,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
       return failure();
 
-    // FIXME: Implement handling for vector sizes/dimensions that are not
-    // supported by SPIRV.
     SmallVector<Type, 1> operandTypes;
     for (auto operand : adaptor.getOperands()) {
+      // This pass only supports operations on vectors whose sizes are already
+      // supported by SPIRV: unsupported vector sizes are distributed into
+      // SPIRV-supported sizes by other blocking optimization passes.
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
         return failure();
       operandTypes.push_back(operand.getType());
@@ -128,7 +127,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     OpBuilder b(parentFunc);
 
     // Create a valid global location removing any metadata attached to the
-    // location as debug info metadata inside of a function cannot be used
+    // location, as debug info metadata inside of a function cannot be used
     // outside of that function.
     auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
     auto globalloc =
@@ -139,7 +138,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
   const StringRef nativeFunc;
 };
 
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) {
   patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_exp");
   patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
@@ -162,22 +161,24 @@ void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
                                                       "__spirv_ocl_native_sqrt");
   patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_tan");
+  if (convertArith)
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(patterns.getContext(),
+                                                          "__spirv_ocl_native_divide");
 }
 
 namespace {
 struct ConvertMathToXeVMPass
     : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
-  ConvertMathToXeVMPass() = default;
+  using Base::Base;
   void runOnOperation() override;
 };
 } // namespace
 
 void ConvertMathToXeVMPass::runOnOperation() {
   auto m = getOperation();
-  // MLIRContext *ctx = m.getContext();
 
   RewritePatternSet patterns(&getContext());
-  populateMathToXeVMConversionPatterns(patterns);
+  populateMathToXeVMConversionPatterns(patterns, convertArith);
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
                          vector::VectorDialect, LLVM::LLVMDialect>();
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 8e1d20dc94d78..ba5de228da411 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH' 
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
 
 module @test_module {
   // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
@@ -23,6 +26,7 @@ module @test_module {
   // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
   // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
   // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+  // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
 
   // CHECK-LABEL: func @math_ops
   func.func @math_ops() {
@@ -142,6 +146,13 @@ module @test_module {
     // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
     %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
 
+    %c6_9_f32 = arith.constant 6.9 : f32
+    %c7_f32 = arith.constant 7. : f32
+
+    // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+    // CHECK-NO-ARITH: arith.divf
+    %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
+
     return
   }
 }
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index 6bc90e34060b4..2492adafd6a50 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -89,10 +89,12 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
       %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
       // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
       %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+      // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
+      %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
 
       llvm.return
     }
-    
+
     llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
     llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
     llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
@@ -111,5 +113,6 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
     llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
     llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
     llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+    llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
   }
 }

>From 3887fe5fe68d0f3caa4d1c2a850684ae97543140 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 15:12:54 -0700
Subject: [PATCH 08/13] clang-format

---
 .../mlir/Conversion/MathToXeVM/MathToXeVM.h   |  3 +-
 mlir/include/mlir/Conversion/Passes.td        | 10 ++---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 37 ++++++++++---------
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 6bb69361dcb6d..91d3c92fd6296 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                          bool convertArith);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 976e1b6b183e1..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,7 +801,8 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
-  let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+  let summary =
+      "Convert (fast) math operations to native XeVM/SPIRV equivalents";
   let description = [{
     This pass converts supported math ops marked with the `afn` fastmath flag
     to function calls for OpenCL `native_` math intrinsics: These intrinsics
@@ -810,10 +811,9 @@ def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
     are implementation-defined, and thus math ops are only converted when they
     have the `afn` fastmath flag enabled.
   }];
-  let options = [
-    Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
-    "Convert supported Arith ops (e.g. arith.divf) as well.">
-  ];
+  let options = [Option<
+      "convertArith", "convert-arith", "bool", /*default=*/"true",
+      "Convert supported Arith ops (e.g. arith.divf) as well.">];
   let dependentDialects = [
     "arith::ArithDialect",
     "func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index b75f8d3640a41..46833735a79dd 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -63,8 +63,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       operandTypes.push_back(operand.getType());
     }
     LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
-    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
-                                              adaptor.getOperands());
+    auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+        op, funcOp, adaptor.getOperands());
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -138,32 +138,33 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
   const StringRef nativeFunc;
 };
 
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) {
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+                                                bool convertArith) {
   patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_exp");
   patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_cos");
-  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_exp2");
+  patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_exp2");
   patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_log");
-  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_log2");
-  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_log10");
-  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_powr");
-  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_rsqrt");
+  patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log2");
+  patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+      patterns.getContext(), "__spirv_ocl_native_log10");
+  patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+      patterns.getContext(), "__spirv_ocl_native_powr");
+  patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_rsqrt");
   patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_sin");
-  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(patterns.getContext(),
-                                                      "__spirv_ocl_native_sqrt");
+  patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+      patterns.getContext(), "__spirv_ocl_native_sqrt");
   patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
                                                       "__spirv_ocl_native_tan");
   if (convertArith)
-    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(patterns.getContext(),
-                                                          "__spirv_ocl_native_divide");
+    patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+        patterns.getContext(), "__spirv_ocl_native_divide");
 }
 
 namespace {

>From 31c911e12f99d15879d1b33cccb6a5c1c622a54c Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 26 Sep 2025 14:19:38 -0700
Subject: [PATCH 09/13] remove todos

---
 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt     | 5 -----
 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir | 3 ---
 2 files changed, 8 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 711c6876bb168..95aaba31a993e 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,3 @@
-# TODO check if everything here is needed
 add_mlir_conversion_library(MLIRMathToXeVM
   MathToXeVM.cpp
 
@@ -12,13 +11,9 @@ add_mlir_conversion_library(MLIRMathToXeVM
   Core
 
   LINK_LIBS PUBLIC
-  MLIRDialectUtils
-  MLIRFuncDialect
-  MLIRGPUToGPURuntimeTransforms
   MLIRMathDialect
   MLIRLLVMCommonConversion
   MLIRPass
   MLIRTransformUtils
   MLIRVectorDialect
-  MLIRVectorUtils
   )
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index ba5de228da411..e1d3b2615e121 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -113,9 +113,6 @@ module @test_module {
 
     // Check all other math operations:
 
-    // native_divide(gentype x, gentype y)
-    // TODO: convert arith.divf to arith/native_divide if option is enabled
-
     // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
     %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
 

>From 20ac59571ab178cd669028a9dd0a1b8f23df598b Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:25:59 -0700
Subject: [PATCH 10/13] Address reviewer comments

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 59 ++++++++-----------
 .../Conversion/MathToXeVM/math-to-xevm.mlir   | 42 ++++++-------
 2 files changed, 44 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 46833735a79dd..0c1f9d39e72a2 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -50,21 +51,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
-      return failure();
+    if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+      return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
 
     SmallVector<Type, 1> operandTypes;
     for (auto operand : adaptor.getOperands()) {
       // This pass only supports operations on vectors that are already in SPIRV
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
-      // supported vetor sizes are done in other blocking optimization passes.
+      // supported vector sizes are done in other blocking optimization passes.
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
-        return failure();
+        return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
       operandTypes.push_back(operand.getType());
     }
-    LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+
+    auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+    auto funcOpRes =
+      LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+    assert(!failed(funcOpRes));
+    LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
     auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, funcOp, adaptor.getOperands());
+    // Preserve the fastmath flags in our MLIR op for later use: We need to
+    // convert our MLIR fastmath attrs into something compatible with llvm.
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -90,49 +99,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     return false;
   }
 
-  LLVM::LLVMFuncOp
-  appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
-    // This function assumes op types have already been validated using
-    // isSPIRVCompatibleFloatOrVec.
-    using LLVM::LLVMFuncOp;
 
-    std::string mangledNativeFunc =
+  inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+    std::string mangledFuncName =
         "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
 
-    auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+    auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
       if (type.isF32())
-        mangledNativeFunc += "f";
+        mangledFuncName += "f";
       else if (type.isF16())
-        mangledNativeFunc += "Dh";
+        mangledFuncName += "Dh";
       else if (type.isF64())
-        mangledNativeFunc += "d";
+        mangledFuncName += "d";
     };
 
     for (auto type : operandTypes) {
       if (auto vecType = dyn_cast<VectorType>(type)) {
-        mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+        mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
         appendFloatToMangledFunc(vecType.getElementType());
       } else
         appendFloatToMangledFunc(type);
     }
 
-    auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
-    auto funcOp =
-        SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
-    if (funcOp)
-      return funcOp;
-
-    auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
-    assert(parentFunc && "expected there to be a parent function");
-    OpBuilder b(parentFunc);
-
-    // Create a valid global location removing any metadata attached to the
-    // location, as debug info metadata inside of a function cannot be used
-    // outside of that function.
-    auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
-    auto globalloc =
-        op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
-    return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+    return mangledFuncName;
   }
 
   const StringRef nativeFunc;
@@ -176,13 +165,11 @@ struct ConvertMathToXeVMPass
 } // namespace
 
 void ConvertMathToXeVMPass::runOnOperation() {
-  auto m = getOperation();
-
   RewritePatternSet patterns(&getContext());
   populateMathToXeVMConversionPatterns(patterns, convertArith);
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
                          vector::VectorDialect, LLVM::LLVMDialect>();
-  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index e1d3b2615e121..04b5906489d00 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -4,29 +4,29 @@
 // RUN:   | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
 
 module @test_module {
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
   //
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
   //
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
-  // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
-  // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
-  // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
-  // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
-  // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+  // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+  // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+  // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+  // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
 
   // CHECK-LABEL: func @math_ops
   func.func @math_ops() {

>From 7b8d0297f8c973d55c445b243295204497ef0d46 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:27:58 -0700
Subject: [PATCH 11/13] clang-format

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 0c1f9d39e72a2..825a7bb79242a 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -51,7 +51,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       return failure();
 
     arith::FastMathFlags fastFlags = op.getFastmath();
-    if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+    if (!(static_cast<uint32_t>(fastFlags) &
+          static_cast<uint32_t>(arith::FastMathFlags::afn)))
       return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
 
     SmallVector<Type, 1> operandTypes;
@@ -60,13 +61,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
       // supported vector sizes are done in other blocking optimization passes.
       if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
-        return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
+        return rewriter.notifyMatchFailure(
+            op, "no equivalent native operation for operand type");
       operandTypes.push_back(operand.getType());
     }
 
     auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
-    auto funcOpRes =
-      LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+    auto funcOpRes = LLVM::lookupOrCreateFn(
+        rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+        operandTypes, op.getType());
     assert(!failed(funcOpRes));
     LLVM::LLVMFuncOp funcOp = funcOpRes.value();
 
@@ -99,8 +102,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
     return false;
   }
 
-
-  inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+  inline std::string
+  getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
     std::string mangledFuncName =
         "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
 
@@ -170,6 +173,7 @@ void ConvertMathToXeVMPass::runOnOperation() {
   ConversionTarget target(getContext());
   target.addLegalDialect<BuiltinDialect, func::FuncDialect,
                          vector::VectorDialect, LLVM::LLVMDialect>();
-  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
     signalPassFailure();
 }

>From 17ad71c766c02550b934d6ca9fa817e7787713c8 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:42:48 -0700
Subject: [PATCH 12/13]  improve comment

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 825a7bb79242a..bbac0af19f0b4 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -75,8 +75,9 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
     auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
         op, funcOp, adaptor.getOperands());
-    // Preserve the fastmath flags in our MLIR op for later use: We need to
-    // convert our MLIR fastmath attrs into something compatible with llvm.
+    // Preserve fastmath flags in our MLIR op when converting to llvm function
+    // calls, in order to allow further fastmath optimizations: We thus need to
+    // convert arith fastmath attrs into attrs recognized by llvm.
     arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
     mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
     callOp->setAttr(fastAttr.getName(), fastAttr.getValue());

>From 203c1f08a1c2e54c3ae20ce3c18c6fe71e910239 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 2 Oct 2025 08:54:38 -0700
Subject: [PATCH 13/13] Improve logging

---
 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index bbac0af19f0b4..156b9a38d07eb 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 #include "../GPUCommon/GPUOpsLowering.h"
 #include "../GPUCommon/OpToFuncCallLowering.h"
@@ -57,13 +58,14 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
 
     SmallVector<Type, 1> operandTypes;
     for (auto operand : adaptor.getOperands()) {
+      Type opTy = operand.getType();
       // This pass only supports operations on vectors that are already in SPIRV
       // supported vector sizes: Distributing unsupported vector sizes to SPIRV
       // supported vector sizes are done in other blocking optimization passes.
-      if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+      if (!isSPIRVCompatibleFloatOrVec(opTy))
         return rewriter.notifyMatchFailure(
-            op, "no equivalent native operation for operand type");
-      operandTypes.push_back(operand.getType());
+            op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+      operandTypes.push_back(opTy);
     }
 
     auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();



More information about the Mlir-commits mailing list