[Mlir-commits] [mlir] [MLIR][SPIRV][XeVM] Add MathToXeVM (`math-to-xevm`) pass (PR #159878)
Ian Li
llvmlistbot at llvm.org
Thu Oct 2 08:55:16 PDT 2025
https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/159878
>From 8af837573e9325d5f9f1e38bd0e4c510a58d37a1 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 19 Sep 2025 16:44:38 -0700
Subject: [PATCH 01/13] Initial mockup of the MathToXeVM pass
---
.../mlir/Conversion/MathToXeVM/MathToXeVM.h | 26 +++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 17 ++
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 24 +++
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 159 ++++++++++++++++++
.../Conversion/MathToXeVM/math-to-xevm.mlir | 22 +++
.../MathToXeVM/native-spirv-builtins.mlir | 33 ++++
8 files changed, 283 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
create mode 100644 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
create mode 100644 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
create mode 100644 mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..7982aa3769e84
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,26 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..ead4d5c50046d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -84,6 +84,7 @@
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..20e0b95cc5c78 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,23 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+ let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+ let description = [{
+ This pass converts supported Math ops to XeVM.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "xevm::XeVMDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..3f389359a6a2c
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+// TODO check if everything here is needed
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..e18350219ffe8
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,159 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFunctionCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO: OCL doesn't provide native int intrinsics, but check what happens
+ // when IGC receives a native_exp on ints anyway
+ // TODO: what about vectorization?
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+ return failure();
+
+ // FIXME: Implement handling for vector sizes/dimensions that are not
+ // supported by SPIRV
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+ return failure();
+ operandTypes.push_back(operand.getType());
+ }
+ LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat()) {
+ return true;
+ } else if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrinsics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+ // This function assumes op types have already been validated using
+ // isSPIRVCompatibleFloatOrVec.
+ using LLVM::LLVMFuncOp;
+
+ std::string mangledNativeFunc =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+ if (type.isF32())
+ mangledNativeFunc += "f";
+ else if (type.isF16())
+ mangledNativeFunc += "Dh";
+ else if (type.isF64())
+ mangledNativeFunc += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+ if (funcOp)
+ return funcOp;
+
+ auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ // Create a valid global location removing any metadata attached to the
+ // location as debug info metadata inside of a function cannot be used
+ // outside of that function.
+ auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+ auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+ return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+ }
+
+ const StringRef nativeFunc;
+};
+
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
+ patterns.getContext(), "__spirv_ocl_native_exp");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ ConvertMathToXeVMPass() = default;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ auto m = getOperation();
+ //MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..436d0e0941b9e
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ // CHECK-LABEL: func @math_ops
+ func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+ %result16 = math.exp %arg_f16 fastmath<fast> : f16
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+ %result64 = math.exp %arg_f64 fastmath<afn> : f64
+
+ // CHECK: math.exp
+ %result_no_fast = math.exp %arg_f64 : f64
+
+ // TODO check fastmath<none>
+
+ func.return %result16, %result64 : f16, f64
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000000000..f762f4b60f818
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN: -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+//
+// MathToXeVM pass generates OpenCL intrinsics function calls when converting
+// Math ops with `fastmath` attr to native function calls. It is assumed that
+// the SPIRV backend would correctly convert these intrinsics calls to OpenCL
+// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+ gpu.module @kernel [#xevm.target] {
+ llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+ // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+ // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+ %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+ // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+ // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+ %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+ // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+ %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+ // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+ %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+ llvm.return
+ }
+ llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ }
+}
>From 59160d682d89a019178cf1e74433e29a5e3a75ad Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:10:29 -0700
Subject: [PATCH 02/13] clang-format
---
mlir/include/mlir/Conversion/Passes.h | 2 +-
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 28 +++++++++++--------
2 files changed, 17 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ead4d5c50046d..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
@@ -84,7 +85,6 @@
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
namespace mlir {
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e18350219ffe8..e1f1205d6efaa 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -38,8 +38,9 @@ using namespace mlir;
template <typename Op>
struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
- ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
- : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
@@ -51,7 +52,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return failure();
arith::FastMathFlags fastFlags = op.getFastmath();
- if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+ if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
return failure();
// FIXME: Implement handling for vector sizes/dimensions that are not
@@ -63,7 +64,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
operandTypes.push_back(operand.getType());
}
LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+ adaptor.getOperands());
return success();
}
@@ -79,13 +81,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
if (shape.size() != 1)
return false;
// SPIRV only allows vectors of size 2, 3, 4, 8, 16.
- if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
return true;
}
return false;
}
- LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+ LLVM::LLVMFuncOp
+ appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
// This function assumes op types have already been validated using
// isSPIRVCompatibleFloatOrVec.
using LLVM::LLVMFuncOp;
@@ -112,7 +116,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
auto funcOp =
- SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
if (funcOp)
return funcOp;
@@ -124,17 +128,17 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
- auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+ auto globalloc =
+ op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
}
const StringRef nativeFunc;
};
-
void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
- patterns.getContext(), "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
}
namespace {
@@ -147,7 +151,7 @@ struct ConvertMathToXeVMPass
void ConvertMathToXeVMPass::runOnOperation() {
auto m = getOperation();
- //MLIRContext *ctx = m.getContext();
+ // MLIRContext *ctx = m.getContext();
RewritePatternSet patterns(&getContext());
populateMathToXeVMConversionPatterns(patterns);
>From 89f94eadf57b3aeb6fb416ed48f1fc7ba1cf1476 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:18:45 -0700
Subject: [PATCH 03/13] fix cmake
---
mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 3f389359a6a2c..711c6876bb168 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,4 @@
-// TODO check if everything here is needed
+# TODO check if everything here is needed
add_mlir_conversion_library(MLIRMathToXeVM
MathToXeVM.cpp
>From bf4ed92c054490aa46d1273c0020c8345ce5c320 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Tue, 23 Sep 2025 16:03:43 -0700
Subject: [PATCH 04/13] update tests
---
.../Conversion/MathToXeVM/math-to-xevm.mlir | 73 +++++++++++++++++--
.../MathToXeVM/native-spirv-builtins.mlir | 40 ++++++++++
2 files changed, 107 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 436d0e0941b9e..00719e9b881d2 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -2,21 +2,82 @@
module @test_module {
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
// CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
// CHECK-LABEL: func @math_ops
- func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ func.func @math_ops() {
+
+ %c1_f16 = arith.constant 1. : f16
+ %c1_f32 = arith.constant 1. : f32
+ %c1_f64 = arith.constant 1. : f64
+
+ // CHECK: math.exp
+ %res_normal_f16 = math.exp %c1_f16 : f16
+ // CHECK: math.exp
+ %res_normal_f32 = math.exp %c1_f32 : f32
+ // CHECK: math.exp
+ %res_normal_f64 = math.exp %c1_f64 : f64
// CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
- %result16 = math.exp %arg_f16 fastmath<fast> : f16
+ %res_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
+ %res_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+ %res_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+ %res_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
+ %res_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
// CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
- %result64 = math.exp %arg_f64 fastmath<afn> : f64
+ %res_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
// CHECK: math.exp
- %result_no_fast = math.exp %arg_f64 : f64
+ %res_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+ // CHECK: math.exp
+ %res_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+ // CHECK: math.exp
+ %res_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+ %v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
+ %v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
+ %v4_c1_f64 = arith.constant dense<1.> : vector<4xf64>
+ %v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
+ %v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64>
+ %res_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64>
+ %res_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64>
+ %res_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+ %res_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64>
+ %res_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+
+ %v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
+ %v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
- // TODO check fastmath<none>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32>
+ %res_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16>
+ %res_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+ %v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
+ %v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
+
+ // CHECK: math.exp
+ %res_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+ // CHECK: math.exp
+ %res_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
- func.return %result16, %result64 : f16, f64
+ return
}
}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index f762f4b60f818..92744c9e165da 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -16,18 +16,58 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
// CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
// CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
%c0_f16 = llvm.mlir.constant(0. : f16) : f16
+ // CHECK-DAG: %[[F32T:.+]] = OpTypeFloat 32
+ // CHECK-DAG: %[[ZERO_F32:.+]] = OpConstantNull %[[F32T]]
+ %c0_f32 = llvm.mlir.constant(0. : f32) : f32
// CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
// CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
%c0_f64 = llvm.mlir.constant(0. : f64) : f64
+ // CHECK-DAG: %[[V2F64T:.+]] = OpTypeVector %[[F64T]] 2
+ // CHECK-DAG: %[[V2_ZERO_F64:.+]] = OpConstantNull %[[V2F64T]]
+ %v2_c0_f64 = llvm.mlir.constant(dense<0.> : vector<2xf64>) : vector<2xf64>
+ // CHECK-DAG: %[[V3F32T:.+]] = OpTypeVector %[[F32T]] 3
+ // CHECK-DAG: %[[V3_ZERO_F32:.+]] = OpConstantNull %[[V3F32T]]
+ %v3_c0_f32 = llvm.mlir.constant(dense<0.> : vector<3xf32>) : vector<3xf32>
+ // CHECK-DAG: %[[V4F64T:.+]] = OpTypeVector %[[F64T]] 4
+ // CHECK-DAG: %[[V4_ZERO_F64:.+]] = OpConstantNull %[[V4F64T]]
+ %v4_c0_f64 = llvm.mlir.constant(dense<0.> : vector<4xf64>) : vector<4xf64>
+ // CHECK-DAG: %[[V8F64T:.+]] = OpTypeVector %[[F64T]] 8
+ // CHECK-DAG: %[[V8_ZERO_F64:.+]] = OpConstantNull %[[V8F64T]]
+ %v8_c0_f64 = llvm.mlir.constant(dense<0.> : vector<8xf64>) : vector<8xf64>
+ // CHECK-DAG: %[[V16F16T:.+]] = OpTypeVector %[[F16T]] 16
+ // CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
+ %v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>
+
// CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
%exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+ // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+ %exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
// CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
%exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+ // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+ %exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
+ // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+ %exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
+ // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+ %exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
+ // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+ %exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
+ // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+ %exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
+
llvm.return
}
llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ llvm.func @_Z22__spirv_ocl_native_expDv2_f64(vector<2xf64>) -> vector<2xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv3_f32(vector<3xf32>) -> vector<3xf32>
+ llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
+ llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
+
+
}
}
>From d7ae5f2a35f3d089fdc997b7374c473c227b9182 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 24 Sep 2025 16:40:30 -0700
Subject: [PATCH 05/13] Add support for all other native opts
---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 29 +++-
.../Conversion/MathToXeVM/math-to-xevm.mlir | 136 +++++++++++++-----
.../MathToXeVM/native-spirv-builtins.mlir | 13 ++
3 files changed, 139 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e1f1205d6efaa..055cfdf064e4e 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -47,7 +48,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
ConversionPatternRewriter &rewriter) const override {
// TODO: OCL doesn't provide native int intrinsics, but check what happens
// when IGC receives a native_exp on ints anyway
- // TODO: what about vectorization?
if (!isSPIRVCompatibleFloatOrVec(op.getType()))
return failure();
@@ -56,7 +56,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return failure();
// FIXME: Implement handling for vector sizes/dimensions that are not
- // supported by SPIRV
+ // supported by SPIRV.
SmallVector<Type, 1> operandTypes;
for (auto operand : adaptor.getOperands()) {
if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
@@ -64,8 +64,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
operandTypes.push_back(operand.getType());
}
LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
adaptor.getOperands());
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
return success();
}
@@ -139,6 +142,26 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
"__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(patterns.getContext(),
+ "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(patterns.getContext(),
+ "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(patterns.getContext(),
+ "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(patterns.getContext(),
+ "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(patterns.getContext(),
+ "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
}
namespace {
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 00719e9b881d2..8e1d20dc94d78 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -4,12 +4,26 @@ module @test_module {
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
// CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
// CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
-
+ //
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+ //
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+ // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+
// CHECK-LABEL: func @math_ops
func.func @math_ops() {
@@ -18,32 +32,36 @@ module @test_module {
%c1_f64 = arith.constant 1. : f64
// CHECK: math.exp
- %res_normal_f16 = math.exp %c1_f16 : f16
+ %exp_normal_f16 = math.exp %c1_f16 : f16
// CHECK: math.exp
- %res_normal_f32 = math.exp %c1_f32 : f32
+ %exp_normal_f32 = math.exp %c1_f32 : f32
// CHECK: math.exp
- %res_normal_f64 = math.exp %c1_f64 : f64
-
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
- %res_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
- %res_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
- %res_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
+ %exp_normal_f64 = math.exp %c1_f64 : f64
+
+ // Check float operations are converted properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fast_f16 = math.exp %c1_f16 fastmath<fast> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+ %exp_fast_f32 = math.exp %c1_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+ %exp_fast_f64 = math.exp %c1_f64 fastmath<fast> : f64
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
- %res_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) : (f32) -> f32
- %res_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
- %res_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %exp_afn_f16 = math.exp %c1_f16 fastmath<afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp_afn_f32 = math.exp %c1_f32 fastmath<afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %exp_afn_f64 = math.exp %c1_f64 fastmath<afn> : f64
// CHECK: math.exp
- %res_none_f16 = math.exp %c1_f16 fastmath<none> : f16
+ %exp_none_f16 = math.exp %c1_f16 fastmath<none> : f16
// CHECK: math.exp
- %res_none_f32 = math.exp %c1_f32 fastmath<none> : f32
+ %exp_none_f32 = math.exp %c1_f32 fastmath<none> : f32
// CHECK: math.exp
- %res_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+ %exp_none_f64 = math.exp %c1_f64 fastmath<none> : f64
+
+ // Check vector operations:
%v2_c1_f64 = arith.constant dense<1.> : vector<2xf64>
%v3_c1_f64 = arith.constant dense<1.> : vector<3xf64>
@@ -51,32 +69,78 @@ module @test_module {
%v8_c1_f64 = arith.constant dense<1.> : vector<8xf64>
%v16_c1_f64 = arith.constant dense<1.> : vector<16xf64>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) : (vector<2xf64>) -> vector<2xf64>
- %res_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) : (vector<3xf64>) -> vector<3xf64>
- %res_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) : (vector<4xf64>) -> vector<4xf64>
- %res_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
- %res_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) : (vector<16xf64>) -> vector<16xf64>
- %res_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv2_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<2xf64>) -> vector<2xf64>
+ %exp_v2_f64 = math.exp %v2_c1_f64 fastmath<afn> : vector<2xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv3_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<3xf64>) -> vector<3xf64>
+ %exp_v3_f64 = math.exp %v3_c1_f64 fastmath<afn> : vector<3xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<4xf64>) -> vector<4xf64>
+ %exp_v4_f64 = math.exp %v4_c1_f64 fastmath<afn> : vector<4xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv8_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+ %exp_v8_f64 = math.exp %v8_c1_f64 fastmath<afn> : vector<8xf64>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf64>) -> vector<16xf64>
+ %exp_v16_f64 = math.exp %v16_c1_f64 fastmath<afn> : vector<16xf64>
%v16_c1_f32 = arith.constant dense<1.> : vector<16xf32>
%v4_c1_f16 = arith.constant dense<1.> : vector<4xf16>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) : (vector<16xf32>) -> vector<16xf32>
- %res_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
- // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) : (vector<4xf16>) -> vector<4xf16>
- %res_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv16_f(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<16xf32>) -> vector<16xf32>
+ %exp_v16_f32 = math.exp %v16_c1_f32 fastmath<fast> : vector<16xf32>
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDv4_Dh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf16>) -> vector<4xf16>
+ %exp_v4_f16 = math.exp %v4_c1_f16 fastmath<fast> : vector<4xf16>
+
+ // Check unsupported vector sizes are not converted:
%v5_c1_f64 = arith.constant dense<1.> : vector<5xf64>
%v32_c1_f64 = arith.constant dense<1.> : vector<32xf64>
// CHECK: math.exp
- %res_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
+ %exp_v5_f64 = math.exp %v5_c1_f64 fastmath<afn> : vector<5xf64>
// CHECK: math.exp
- %res_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+ %exp_v32_f64 = math.exp %v32_c1_f64 fastmath<afn> : vector<32xf64>
+
+ // Check fastmath flags propagate properly:
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f16) -> f16
+ %exp_fastmath_all_f16 = math.exp %c1_f16 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f16
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf, nsz, arcp, contract, afn>} : (f32) -> f32
+ %exp_fastmath_most_f32 = math.exp %c1_f32 fastmath<nnan,ninf,nsz,arcp,contract,afn> : f32
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expf(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, afn, reassoc>} : (f32) -> f32
+ %exp_afn_reassoc_nnan_f32 = math.exp %c1_f32 fastmath<afn,reassoc,nnan> : f32
+
+ // Check all other math operations:
+
+ // native_divide(gentype x, gentype y)
+ // TODO: convert arith.divf to arith/native_divide if option is enabled
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_exp2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %exp2_afn_f32 = math.exp2 %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_logDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %log_afn_f16 = math.log %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_log2f(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %log2_afn_f32 = math.log2 %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z24__spirv_ocl_native_log10d(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %log10_afn_f64 = math.log10 %c1_f64 fastmath<afn> : f64
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_powrDhDh(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16, f16) -> f16
+ %powr_afn_f16 = math.powf %c1_f16, %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z24__spirv_ocl_native_rsqrtd(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %rsqrt_afn_f64 = math.rsqrt %c1_f64 fastmath<afn> : f64
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_sinDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ %sin_afn_f16 = math.sin %c1_f16 fastmath<afn> : f16
+
+ // CHECK: llvm.call @_Z23__spirv_ocl_native_sqrtf(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ %sqrt_afn_f32 = math.sqrt %c1_f32 fastmath<afn> : f32
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ %tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
return
}
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index 92744c9e165da..b83288c7ec99e 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -57,6 +57,19 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
// CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
%exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
+
+ // SPIRV backend does not currently handle fastmath flags: The SPIRV
+ // backend would need to generate OpDecorate calls to decorate math ops
+ // with FPFastMathMode/FPFastMathModeINTEL decorations.
+ //
+ // FIXME: When support for fastmath flags in the SPIRV backend is added,
+ // add tests here to ensure fastmath flags are converted to the correct
+ // OpDecorate calls.
+ //
+ // See:
+ // - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
+ // - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
+
llvm.return
}
llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
>From 9d3a540934345628510df63b1ac69312d76cea8a Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 12:16:24 -0700
Subject: [PATCH 06/13] finish testing for native-spirv-builtins
---
.../MathToXeVM/native-spirv-builtins.mlir | 51 +++++++++++++++----
1 file changed, 40 insertions(+), 11 deletions(-)
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index b83288c7ec99e..6bc90e34060b4 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -39,25 +39,24 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
// CHECK-DAG: %[[V16_ZERO_F16:.+]] = OpConstantNull %[[V16F16T]]
%v16_c0_f16 = llvm.mlir.constant(dense<0.> : vector<16xf16>) : vector<16xf16>
- // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
%exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
- // CHECK: %{{.+}} = OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp %[[ZERO_F32]]
%exp_f32 = llvm.call @_Z22__spirv_ocl_native_expf(%c0_f32) : (f32) -> f32
- // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
%exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
- // CHECK: %{{.+}} = OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
+ // CHECK: OpExtInst %[[V2F64T]] %{{.+}} native_exp %[[V2_ZERO_F64]]
%exp_v2_f64 = llvm.call @_Z22__spirv_ocl_native_expDv2_f64(%v2_c0_f64) : (vector<2xf64>) -> vector<2xf64>
- // CHECK: %{{.+}} = OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
+ // CHECK: OpExtInst %[[V3F32T]] %{{.+}} native_exp %[[V3_ZERO_F32]]
%exp_v3_f32 = llvm.call @_Z22__spirv_ocl_native_expDv3_f32(%v3_c0_f32) : (vector<3xf32>) -> vector<3xf32>
- // CHECK: %{{.+}} = OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
+ // CHECK: OpExtInst %[[V4F64T]] %{{.+}} native_exp %[[V4_ZERO_F64]]
%exp_v4_f64 = llvm.call @_Z22__spirv_ocl_native_expDv4_f64(%v4_c0_f64) : (vector<4xf64>) -> vector<4xf64>
- // CHECK: %{{.+}} = OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
+ // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_exp %[[V8_ZERO_F64]]
%exp_v8_f64 = llvm.call @_Z22__spirv_ocl_native_expDv8_f64(%v8_c0_f64) : (vector<8xf64>) -> vector<8xf64>
- // CHECK: %{{.+}} = OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
+ // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_exp %[[V16_ZERO_F16]]
%exp_v16_f16 = llvm.call @_Z22__spirv_ocl_native_expDv16_f16(%v16_c0_f16) : (vector<16xf16>) -> vector<16xf16>
-
// SPIRV backend does not currently handle fastmath flags: The SPIRV
// backend would need to generate OpDecorate calls to decorate math ops
// with FPFastMathMode/FPFastMathModeINTEL decorations.
@@ -70,8 +69,30 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
// - https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_math_extended_instructions
// - https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpDecorate
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_cos %[[ZERO_F16]]
+ %cos_afn_f16 = llvm.call @_Z22__spirv_ocl_native_cosDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_exp2 %[[ZERO_F32]]
+ %exp2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_exp2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_log %[[ZERO_F16]]
+ %log_afn_f16 = llvm.call @_Z22__spirv_ocl_native_logDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_log2 %[[ZERO_F32]]
+ %log2_afn_f32 = llvm.call @_Z23__spirv_ocl_native_log2f(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[V8F64T]] %{{.+}} native_log10 %[[V8_ZERO_F64]]
+ %log10_afn_f64 = llvm.call @_Z24__spirv_ocl_native_log10Dv8_d(%v8_c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (vector<8xf64>) -> vector<8xf64>
+ // CHECK: OpExtInst %[[V16F16T]] %{{.+}} native_powr %[[V16_ZERO_F16]] %[[V16_ZERO_F16]]
+ %powr_afn_f16 = llvm.call @_Z23__spirv_ocl_native_powrDv16_DhS_(%v16_c0_f16, %v16_c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_rsqrt %[[ZERO_F64]]
+ %rsqrt_afn_f64 = llvm.call @_Z24__spirv_ocl_native_rsqrtd(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ // CHECK: OpExtInst %[[F16T]] %{{.+}} native_sin %[[ZERO_F16]]
+ %sin_afn_f16 = llvm.call @_Z22__spirv_ocl_native_sinDh(%c0_f16) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_sqrt %[[ZERO_F32]]
+ %sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+ // CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
+ %tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+
llvm.return
}
+
llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
@@ -80,7 +101,15 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
llvm.func @_Z22__spirv_ocl_native_expDv4_f64(vector<4xf64>) -> vector<4xf64>
llvm.func @_Z22__spirv_ocl_native_expDv8_f64(vector<8xf64>) -> vector<8xf64>
llvm.func @_Z22__spirv_ocl_native_expDv16_f16(vector<16xf16>) -> vector<16xf16>
-
-
+ llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ llvm.func @_Z24__spirv_ocl_native_log10Dv8_d(vector<8xf64>) -> vector<8xf64>
+ llvm.func @_Z23__spirv_ocl_native_powrDv16_DhS_(vector<16xf16>, vector<16xf16>) -> vector<16xf16>
+ llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
}
}
>From 0232c265a210231da53081548392247ab27017f3 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 14:59:40 -0700
Subject: [PATCH 07/13] Accommodate arith.divf
---
.../mlir/Conversion/MathToXeVM/MathToXeVM.h | 2 +-
mlir/include/mlir/Conversion/Passes.td | 13 +++++++++++--
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 19 ++++++++++---------
.../Conversion/MathToXeVM/math-to-xevm.mlir | 13 ++++++++++++-
.../MathToXeVM/native-spirv-builtins.mlir | 5 ++++-
5 files changed, 38 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 7982aa3769e84..6bb69361dcb6d 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,7 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 20e0b95cc5c78..976e1b6b183e1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,10 +801,19 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
//===----------------------------------------------------------------------===//
def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
- let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+ let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
let description = [{
- This pass converts supported Math ops to XeVM.
+ This pass converts supported math ops marked with the `afn` fastmath flag
+ to function calls for OpenCL `native_` math intrinsics: These intrinsics
+ are typically mapped directly to native device instructions, often resulting
+ in better performance. However, the precision/error of these intrinsics
+ are implementation-defined, and thus math ops are only converted when they
+ have the `afn` fastmath flag enabled.
}];
+ let options = [
+ Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
+ "Convert supported Arith ops (e.g. arith.divf) as well.">
+ ];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 055cfdf064e4e..b75f8d3640a41 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -46,8 +46,6 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: OCL doesn't provide native int intrinsics, but check what happens
- // when IGC receives a native_exp on ints anyway
if (!isSPIRVCompatibleFloatOrVec(op.getType()))
return failure();
@@ -55,10 +53,11 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
return failure();
- // FIXME: Implement handling for vector sizes/dimensions that are not
- // supported by SPIRV.
SmallVector<Type, 1> operandTypes;
for (auto operand : adaptor.getOperands()) {
+ // This pass only supports operations on vectors that are already in SPIRV
+ // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+ // supported vector sizes is done in other blocking optimization passes.
if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
return failure();
operandTypes.push_back(operand.getType());
@@ -128,7 +127,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
OpBuilder b(parentFunc);
// Create a valid global location removing any metadata attached to the
- // location as debug info metadata inside of a function cannot be used
+ // location, as debug info metadata inside of a function cannot be used
// outside of that function.
auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
auto globalloc =
@@ -139,7 +138,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
const StringRef nativeFunc;
};
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) {
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
"__spirv_ocl_native_exp");
patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
@@ -162,22 +161,24 @@ void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
"__spirv_ocl_native_sqrt");
patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
"__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(patterns.getContext(),
+ "__spirv_ocl_native_divide");
}
namespace {
struct ConvertMathToXeVMPass
: public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
- ConvertMathToXeVMPass() = default;
+ using Base::Base;
void runOnOperation() override;
};
} // namespace
void ConvertMathToXeVMPass::runOnOperation() {
auto m = getOperation();
- // MLIRContext *ctx = m.getContext();
RewritePatternSet patterns(&getContext());
- populateMathToXeVMConversionPatterns(patterns);
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index 8e1d20dc94d78..ba5de228da411 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+// RUN: mlir-opt %s -convert-math-to-xevm \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-ARITH'
+// RUN: mlir-opt %s -convert-math-to-xevm='convert-arith=false' \
+// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
module @test_module {
// CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
@@ -23,6 +26,7 @@ module @test_module {
// CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
// CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
// CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
// CHECK-LABEL: func @math_ops
func.func @math_ops() {
@@ -142,6 +146,13 @@ module @test_module {
// CHECK: llvm.call @_Z22__spirv_ocl_native_tand(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
%tan_afn_f64 = math.tan %c1_f64 fastmath<afn> : f64
+ %c6_9_f32 = arith.constant 6.9 : f32
+ %c7_f32 = arith.constant 7. : f32
+
+ // CHECK-ARITH: llvm.call @_Z25__spirv_ocl_native_divideff(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+ // CHECK-NO-ARITH: arith.divf
+ %divf_afn_f32 = arith.divf %c6_9_f32, %c7_f32 fastmath<afn> : f32
+
return
}
}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
index 6bc90e34060b4..2492adafd6a50 100644
--- a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -89,10 +89,12 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
%sqrt_afn_f32 = llvm.call @_Z23__spirv_ocl_native_sqrtf(%c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
// CHECK: OpExtInst %[[F64T]] %{{.+}} native_tan %[[ZERO_F64]]
%tan_afn_f64 = llvm.call @_Z22__spirv_ocl_native_tand(%c0_f64) {fastmathFlags = #llvm.fastmath<afn>} : (f64) -> f64
+ // CHECK: OpExtInst %[[F32T]] %{{.+}} native_divide %[[ZERO_F32]] %[[ZERO_F32]]
+ %divide_afn_f32 = llvm.call @_Z25__spirv_ocl_native_divideff(%c0_f32, %c0_f32) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
llvm.return
}
-
+
llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
@@ -111,5 +113,6 @@ module @test_ocl_intrinsics attributes {gpu.container_module} {
llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
}
}
>From 3887fe5fe68d0f3caa4d1c2a850684ae97543140 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 25 Sep 2025 15:12:54 -0700
Subject: [PATCH 08/13] clang-format
---
.../mlir/Conversion/MathToXeVM/MathToXeVM.h | 3 +-
mlir/include/mlir/Conversion/Passes.td | 10 ++---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 37 ++++++++++---------
3 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
index 6bb69361dcb6d..91d3c92fd6296 100644
--- a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -20,7 +20,8 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
/// Populate the given list with patterns that convert from Math to XeVM calls.
-void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith);
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 976e1b6b183e1..5817babf68ddb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -801,7 +801,8 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
//===----------------------------------------------------------------------===//
def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
- let summary = "Convert (fast) math operations to native XeVM/SPIRV equivalents";
+ let summary =
+ "Convert (fast) math operations to native XeVM/SPIRV equivalents";
let description = [{
This pass converts supported math ops marked with the `afn` fastmath flag
to function calls for OpenCL `native_` math intrinsics: These intrinsics
@@ -810,10 +811,9 @@ def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
are implementation-defined, and thus math ops are only converted when they
have the `afn` fastmath flag enabled.
}];
- let options = [
- Option<"convertArith", "convert-arith", "bool", /*default=*/"true",
- "Convert supported Arith ops (e.g. arith.divf) as well.">
- ];
+ let options = [Option<
+ "convertArith", "convert-arith", "bool", /*default=*/"true",
+ "Convert supported Arith ops (e.g. arith.divf) as well.">];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index b75f8d3640a41..46833735a79dd 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
@@ -63,8 +63,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
operandTypes.push_back(operand.getType());
}
LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
- auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
- adaptor.getOperands());
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -138,32 +138,33 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
const StringRef nativeFunc;
};
-void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns, bool convertArith) {
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
"__spirv_ocl_native_exp");
patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
"__spirv_ocl_native_cos");
- patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(patterns.getContext(),
- "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
"__spirv_ocl_native_log");
- patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(patterns.getContext(),
- "__spirv_ocl_native_log2");
- patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(patterns.getContext(),
- "__spirv_ocl_native_log10");
- patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(patterns.getContext(),
- "__spirv_ocl_native_powr");
- patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(patterns.getContext(),
- "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
"__spirv_ocl_native_sin");
- patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(patterns.getContext(),
- "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
"__spirv_ocl_native_tan");
if (convertArith)
- patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(patterns.getContext(),
- "__spirv_ocl_native_divide");
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
}
namespace {
>From 31c911e12f99d15879d1b33cccb6a5c1c622a54c Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 26 Sep 2025 14:19:38 -0700
Subject: [PATCH 09/13] remove todos
---
mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 5 -----
mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir | 3 ---
2 files changed, 8 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 711c6876bb168..95aaba31a993e 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,3 @@
-# TODO check if everything here is needed
add_mlir_conversion_library(MLIRMathToXeVM
MathToXeVM.cpp
@@ -12,13 +11,9 @@ add_mlir_conversion_library(MLIRMathToXeVM
Core
LINK_LIBS PUBLIC
- MLIRDialectUtils
- MLIRFuncDialect
- MLIRGPUToGPURuntimeTransforms
MLIRMathDialect
MLIRLLVMCommonConversion
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
- MLIRVectorUtils
)
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index ba5de228da411..e1d3b2615e121 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -113,9 +113,6 @@ module @test_module {
// Check all other math operations:
- // native_divide(gentype x, gentype y)
- // TODO: convert arith.divf to arith/native_divide if option is enabled
-
// CHECK: llvm.call @_Z22__spirv_ocl_native_cosDh(%{{.*}}) {fastmathFlags = #llvm.fastmath<afn>} : (f16) -> f16
%cos_afn_f16 = math.cos %c1_f16 fastmath<afn> : f16
>From 20ac59571ab178cd669028a9dd0a1b8f23df598b Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:25:59 -0700
Subject: [PATCH 10/13] Address reviewer comments
---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 59 ++++++++-----------
.../Conversion/MathToXeVM/math-to-xevm.mlir | 42 ++++++-------
2 files changed, 44 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 46833735a79dd..0c1f9d39e72a2 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -12,6 +12,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -50,21 +51,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return failure();
arith::FastMathFlags fastFlags = op.getFastmath();
- if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
- return failure();
+ if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+ return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
SmallVector<Type, 1> operandTypes;
for (auto operand : adaptor.getOperands()) {
// This pass only supports operations on vectors that are already in SPIRV
// supported vector sizes: Distributing unsupported vector sizes to SPIRV
- // supported vetor sizes are done in other blocking optimization passes.
+ // supported vector sizes are done in other blocking optimization passes.
if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
- return failure();
+ return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
operandTypes.push_back(operand.getType());
}
- LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+
+ auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+ auto funcOpRes =
+ LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, funcOp, adaptor.getOperands());
+ // Preserve the fastmath flags in our MLIR op for later use: We need to
+ // convert our MLIR fastmath attrs into something compatible with llvm.
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
@@ -90,49 +99,29 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return false;
}
- LLVM::LLVMFuncOp
- appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
- // This function assumes op types have already been validated using
- // isSPIRVCompatibleFloatOrVec.
- using LLVM::LLVMFuncOp;
- std::string mangledNativeFunc =
+ inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ std::string mangledFuncName =
"_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
- auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+ auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
if (type.isF32())
- mangledNativeFunc += "f";
+ mangledFuncName += "f";
else if (type.isF16())
- mangledNativeFunc += "Dh";
+ mangledFuncName += "Dh";
else if (type.isF64())
- mangledNativeFunc += "d";
+ mangledFuncName += "d";
};
for (auto type : operandTypes) {
if (auto vecType = dyn_cast<VectorType>(type)) {
- mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
appendFloatToMangledFunc(vecType.getElementType());
} else
appendFloatToMangledFunc(type);
}
- auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
- auto funcOp =
- SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
- if (funcOp)
- return funcOp;
-
- auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
- assert(parentFunc && "expected there to be a parent function");
- OpBuilder b(parentFunc);
-
- // Create a valid global location removing any metadata attached to the
- // location, as debug info metadata inside of a function cannot be used
- // outside of that function.
- auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
- auto globalloc =
- op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
- return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+ return mangledFuncName;
}
const StringRef nativeFunc;
@@ -176,13 +165,11 @@ struct ConvertMathToXeVMPass
} // namespace
void ConvertMathToXeVMPass::runOnOperation() {
- auto m = getOperation();
-
RewritePatternSet patterns(&getContext());
populateMathToXeVMConversionPatterns(patterns, convertArith);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
- if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
index e1d3b2615e121..04b5906489d00 100644
--- a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -4,29 +4,29 @@
// RUN: | FileCheck %s -check-prefixes='CHECK,CHECK-NO-ARITH'
module @test_module {
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expf(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
//
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
- // CHECK: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv2_d(vector<2xf64>) -> vector<2xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv3_d(vector<3xf64>) -> vector<3xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_d(vector<4xf64>) -> vector<4xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv8_d(vector<8xf64>) -> vector<8xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_d(vector<16xf64>) -> vector<16xf64>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv16_f(vector<16xf32>) -> vector<16xf32>
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_expDv4_Dh(vector<4xf16>) -> vector<4xf16>
//
- // CHECK: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
- // CHECK: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
- // CHECK: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
- // CHECK: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
- // CHECK: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
- // CHECK: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
- // CHECK: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
- // CHECK: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
- // CHECK: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
- // CHECK: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
- // CHECK-ARITH: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_cosDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_exp2f(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_logDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_log2f(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_log10d(f64) -> f64
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_powrDhDh(f16, f16) -> f16
+ // CHECK-DAG: llvm.func @_Z24__spirv_ocl_native_rsqrtd(f64) -> f64
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_sinDh(f16) -> f16
+ // CHECK-DAG: llvm.func @_Z23__spirv_ocl_native_sqrtf(f32) -> f32
+ // CHECK-DAG: llvm.func @_Z22__spirv_ocl_native_tand(f64) -> f64
+ // CHECK-ARITH-DAG: llvm.func @_Z25__spirv_ocl_native_divideff(f32, f32) -> f32
// CHECK-LABEL: func @math_ops
func.func @math_ops() {
>From 7b8d0297f8c973d55c445b243295204497ef0d46 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:27:58 -0700
Subject: [PATCH 11/13] clang-format
---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 0c1f9d39e72a2..825a7bb79242a 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -51,7 +51,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return failure();
arith::FastMathFlags fastFlags = op.getFastmath();
- if (!(static_cast<uint32_t>(fastFlags) & static_cast<uint32_t>(arith::FastMathFlags::afn)))
+ if (!(static_cast<uint32_t>(fastFlags) &
+ static_cast<uint32_t>(arith::FastMathFlags::afn)))
return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
SmallVector<Type, 1> operandTypes;
@@ -60,13 +61,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
// supported vector sizes: Distributing unsupported vector sizes to SPIRV
// supported vector sizes are done in other blocking optimization passes.
if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
- return rewriter.notifyMatchFailure(op, "no equivalent native operation for operand type");
+ return rewriter.notifyMatchFailure(
+ op, "no equivalent native operation for operand type");
operandTypes.push_back(operand.getType());
}
auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
- auto funcOpRes =
- LLVM::lookupOrCreateFn(rewriter, moduleOp, getMangledNativeFuncName(operandTypes), operandTypes, op.getType());
+ auto funcOpRes = LLVM::lookupOrCreateFn(
+ rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+ operandTypes, op.getType());
assert(!failed(funcOpRes));
LLVM::LLVMFuncOp funcOp = funcOpRes.value();
@@ -99,8 +102,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return false;
}
-
- inline std::string getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ inline std::string
+ getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
std::string mangledFuncName =
"_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
@@ -170,6 +173,7 @@ void ConvertMathToXeVMPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
- if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
>From 17ad71c766c02550b934d6ca9fa817e7787713c8 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Wed, 1 Oct 2025 12:42:48 -0700
Subject: [PATCH 12/13] improve comment
---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index 825a7bb79242a..bbac0af19f0b4 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -75,8 +75,9 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, funcOp, adaptor.getOperands());
- // Preserve the fastmath flags in our MLIR op for later use: We need to
- // convert our MLIR fastmath attrs into something compatible with llvm.
+ // Carry the op's fastmath flags over to the generated llvm.call so that
+ // later passes can still apply fastmath optimizations. The arith fastmath
+ // attributes must first be converted into their LLVM dialect equivalents.
arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
>From 203c1f08a1c2e54c3ae20ce3c18c6fe71e910239 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Thu, 2 Oct 2025 08:54:38 -0700
Subject: [PATCH 13/13] Improve logging
---
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index bbac0af19f0b4..156b9a38d07eb 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -57,13 +58,14 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
SmallVector<Type, 1> operandTypes;
for (auto operand : adaptor.getOperands()) {
+ Type opTy = operand.getType();
// This pass only supports operations on vectors that are already in SPIRV
// supported vector sizes: Distributing unsupported vector sizes to SPIRV
// supported vector sizes are done in other blocking optimization passes.
- if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+ if (!isSPIRVCompatibleFloatOrVec(opTy))
return rewriter.notifyMatchFailure(
- op, "no equivalent native operation for operand type");
- operandTypes.push_back(operand.getType());
+ op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+ operandTypes.push_back(opTy);
}
auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
More information about the Mlir-commits
mailing list