[Mlir-commits] [mlir] [MLIR][SPIRV][XeVM] Add support for fastmath `afn` option using native OpenCL intrinsics (PR #159878)
Ian Li
llvmlistbot at llvm.org
Mon Sep 22 08:18:57 PDT 2025
https://github.com/ianayl updated https://github.com/llvm/llvm-project/pull/159878
>From 8af837573e9325d5f9f1e38bd0e4c510a58d37a1 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Fri, 19 Sep 2025 16:44:38 -0700
Subject: [PATCH 1/3] Initial mockup of the MathToXeVM pass
---
.../mlir/Conversion/MathToXeVM/MathToXeVM.h | 26 +++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 17 ++
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 24 +++
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 159 ++++++++++++++++++
.../Conversion/MathToXeVM/math-to-xevm.mlir | 22 +++
.../MathToXeVM/native-spirv-builtins.mlir | 33 ++++
8 files changed, 283 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
create mode 100644 mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
create mode 100644 mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
create mode 100644 mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
diff --git a/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
new file mode 100644
index 0000000000000..7982aa3769e84
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToXeVM/MathToXeVM.h
@@ -0,0 +1,26 @@
+//===- MathToXeVM.h - Utils for converting Math to XeVM -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+#define MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to XeVM calls.
+void populateMathToXeVMConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOXEVM_MATHTOXEVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..ead4d5c50046d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -84,6 +84,7 @@
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..20e0b95cc5c78 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -796,6 +796,23 @@ def ConvertMathToSPIRVPass : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
+//===----------------------------------------------------------------------===//
+// MathToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToXeVM : Pass<"convert-math-to-xevm", "ModuleOp"> {
+ let summary = "Convert Math dialect to XeVM"; // TODO: what do I call this?
+ let description = [{
+ This pass converts supported Math ops to XeVM.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "xevm::XeVMDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..bebf1b8fff3f9 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..3f389359a6a2c
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+// TODO check if everything here is needed
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000000000..e18350219ffe8
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,159 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+// GPUCommon/OpToFuncCallLowering is not used here, as it doesn't handle
+// native functions/intrinsics that take vector operands.
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // TODO: OCL doesn't provide native int intrinsics, but check what happens
+ // when IGC receives a native_exp on ints anyway
+ // TODO: what about vectorization?
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+ return failure();
+
+ // FIXME: Implement handling for vector sizes/dimensions that are not
+ // supported by SPIRV
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ if (!isSPIRVCompatibleFloatOrVec(operand.getType()))
+ return failure();
+ operandTypes.push_back(operand.getType());
+ }
+ LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat()) {
+ return true;
+ } else if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrinsics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+ // This function assumes op types have already been validated using
+ // isSPIRVCompatibleFloatOrVec.
+ using LLVM::LLVMFuncOp;
+
+ std::string mangledNativeFunc =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledNativeFunc](Type type) {
+ if (type.isF32())
+ mangledNativeFunc += "f";
+ else if (type.isF16())
+ mangledNativeFunc += "Dh";
+ else if (type.isF64())
+ mangledNativeFunc += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledNativeFunc += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+ if (funcOp)
+ return funcOp;
+
+ auto parentFunc = op->template getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ // Create a valid global location removing any metadata attached to the
+ // location as debug info metadata inside of a function cannot be used
+ // outside of that function.
+ auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
+ auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+ return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
+ }
+
+ const StringRef nativeFunc;
+};
+
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
+ patterns.getContext(), "__spirv_ocl_native_exp");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ ConvertMathToXeVMPass() = default;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ auto m = getOperation();
+ //MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
new file mode 100644
index 0000000000000..436d0e0941b9e
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/math-to-xevm.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -convert-math-to-xevm | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ // CHECK: llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ // CHECK-LABEL: func @math_ops
+ func.func @math_ops(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expDh(%{{.*}}) : (f16) -> f16
+ %result16 = math.exp %arg_f16 fastmath<fast> : f16
+
+ // CHECK: llvm.call @_Z22__spirv_ocl_native_expd(%{{.*}}) : (f64) -> f64
+ %result64 = math.exp %arg_f64 fastmath<afn> : f64
+
+ // CHECK: math.exp
+ %result_no_fast = math.exp %arg_f64 : f64
+
+ // TODO check fastmath<none>
+
+ func.return %result16, %result64 : f16, f64
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
new file mode 100644
index 0000000000000..f762f4b60f818
--- /dev/null
+++ b/mlir/test/Conversion/MathToXeVM/native-spirv-builtins.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -gpu-module-to-binary="format=isa" \
+// RUN: -debug-only=serialize-to-isa 2> %t
+// RUN: FileCheck --input-file=%t %s
+//
+// MathToXeVM pass generates OpenCL intrinsic function calls when converting
+// Math ops with `fastmath` attr to native function calls. It is assumed that
+// the SPIRV backend would correctly convert these intrinsic calls to OpenCL
+// ExtInst instructions in SPIRV (See llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp).
+//
+// To ensure this assumption holds, this test verifies that the SPIRV backend
+// behaves as expected.
+
+module @test_ocl_intrinsics attributes {gpu.container_module} {
+ gpu.module @kernel [#xevm.target] {
+ llvm.func spir_kernelcc @native_fcns() attributes {gpu.kernel} {
+ // CHECK-DAG: %[[F16T:.+]] = OpTypeFloat 16
+ // CHECK-DAG: %[[ZERO_F16:.+]] = OpConstantNull %[[F16T]]
+ %c0_f16 = llvm.mlir.constant(0. : f16) : f16
+ // CHECK-DAG: %[[F64T:.+]] = OpTypeFloat 64
+ // CHECK-DAG: %[[ZERO_F64:.+]] = OpConstantNull %[[F64T]]
+ %c0_f64 = llvm.mlir.constant(0. : f64) : f64
+
+ // CHECK: %{{.+}} = OpExtInst %[[F16T]] %{{.+}} native_exp %[[ZERO_F16]]
+ %exp_f16 = llvm.call @_Z22__spirv_ocl_native_expDh(%c0_f16) : (f16) -> f16
+ // CHECK: %{{.+}} = OpExtInst %[[F64T]] %{{.+}} native_exp %[[ZERO_F64]]
+ %exp_f64 = llvm.call @_Z22__spirv_ocl_native_expd(%c0_f64) : (f64) -> f64
+
+ llvm.return
+ }
+ llvm.func @_Z22__spirv_ocl_native_expDh(f16) -> f16
+ llvm.func @_Z22__spirv_ocl_native_expd(f64) -> f64
+ }
+}
>From 59160d682d89a019178cf1e74433e29a5e3a75ad Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:10:29 -0700
Subject: [PATCH 2/3] clang-format
---
mlir/include/mlir/Conversion/Passes.h | 2 +-
mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp | 28 +++++++++++--------
2 files changed, 17 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index ead4d5c50046d..40d866ec7bf10 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
@@ -84,7 +85,6 @@
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
-#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
namespace mlir {
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
index e18350219ffe8..e1f1205d6efaa 100644
--- a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -38,8 +38,9 @@ using namespace mlir;
template <typename Op>
struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
- ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc, PatternBenefit benefit = 1)
- : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
@@ -51,7 +52,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
return failure();
arith::FastMathFlags fastFlags = op.getFastmath();
- if (!((uint32_t) fastFlags & (uint32_t) arith::FastMathFlags::afn))
+ if (!((uint32_t)fastFlags & (uint32_t)arith::FastMathFlags::afn))
return failure();
// FIXME: Implement handling for vector sizes/dimensions that are not
@@ -63,7 +64,8 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
operandTypes.push_back(operand.getType());
}
LLVM::LLVMFuncOp funcOp = appendOrGetFuncOp(op, operandTypes);
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp, adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, funcOp,
+ adaptor.getOperands());
return success();
}
@@ -79,13 +81,15 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
if (shape.size() != 1)
return false;
// SPIRV only allows vectors of size 2, 3, 4, 8, 16.
- if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 || shape[0] == 16)
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
return true;
}
return false;
}
- LLVM::LLVMFuncOp appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
+ LLVM::LLVMFuncOp
+ appendOrGetFuncOp(Op &op, const SmallVector<Type, 1> &operandTypes) const {
// This function assumes op types have already been validated using
// isSPIRVCompatibleFloatOrVec.
using LLVM::LLVMFuncOp;
@@ -112,7 +116,7 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
auto funcAttr = StringAttr::get(op->getContext(), mangledNativeFunc);
auto funcOp =
- SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
+ SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
if (funcOp)
return funcOp;
@@ -124,17 +128,17 @@ struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
// location as debug info metadata inside of a function cannot be used
// outside of that function.
auto funcType = LLVM::LLVMFunctionType::get(op.getType(), operandTypes);
- auto globalloc = op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
+ auto globalloc =
+ op->getLoc()->template findInstanceOfOrUnknown<FileLineColLoc>();
return LLVMFuncOp::create(b, globalloc, mangledNativeFunc, funcType);
}
const StringRef nativeFunc;
};
-
void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(
- patterns.getContext(), "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
}
namespace {
@@ -147,7 +151,7 @@ struct ConvertMathToXeVMPass
void ConvertMathToXeVMPass::runOnOperation() {
auto m = getOperation();
- //MLIRContext *ctx = m.getContext();
+ // MLIRContext *ctx = m.getContext();
RewritePatternSet patterns(&getContext());
populateMathToXeVMConversionPatterns(patterns);
>From 89f94eadf57b3aeb6fb416ed48f1fc7ba1cf1476 Mon Sep 17 00:00:00 2001
From: Ian Li <ian.li at intel.com>
Date: Mon, 22 Sep 2025 08:18:45 -0700
Subject: [PATCH 3/3] fix cmake
---
mlir/lib/Conversion/MathToXeVM/CMakeLists.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
index 3f389359a6a2c..711c6876bb168 100644
--- a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -1,4 +1,4 @@
-// TODO check if everything here is needed
+# TODO check if everything here is needed
add_mlir_conversion_library(MLIRMathToXeVM
MathToXeVM.cpp
More information about the Mlir-commits
mailing list