[Mlir-commits] [mlir] [mlir][x86] Hardware extension namespaces (PR #184392)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Mar 23 06:29:53 PDT 2026
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/184392
>From 29a0491f0f37ebe5f75a7841441b741603fe070f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 3 Mar 2026 18:14:25 +0100
Subject: [PATCH] [mlir][x86] Hardware extension namespaces
Adds hardware-extension C++ namespaces (`avx512`, `avx10`, `avx`) to the
X86 dialect op definitions so that they match the extension prefix of each
op's IR mnemonic (e.g. `avx512.`, `avx10.`, `avx.`).
Nested namespaces improve source code readability by explicitly
indicating which hardware extension each operation requires, and they
align the naming scheme between the C++ code and the IR.
---
mlir/include/mlir/Dialect/X86/X86.td | 24 ++++++++++++-------
.../Transforms/PolynomialApproximation.cpp | 2 +-
mlir/lib/Dialect/X86/IR/X86Dialect.cpp | 16 ++++++-------
.../X86/Transforms/ShuffleVectorFMAOps.cpp | 15 ++++++------
.../Transforms/VectorContractBF16ToFMA.cpp | 17 ++++++-------
.../VectorContractToPackedTypeDotProduct.cpp | 6 ++---
6 files changed, 45 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index e8965d04c2145..814bf884395bb 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -35,8 +35,10 @@ def X86_Dialect : Dialect {
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX512_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx512." # mnemonic, traits> {}
+class AVX512_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx512." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx512";
+}
//----------------------------------------------------------------------------//
// MaskCompressOp
@@ -351,8 +353,10 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX10_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx10." # mnemonic, traits> {}
+class AVX10_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx10." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx10";
+}
//----------------------------------------------------------------------------//
// AVX10 Int8 Dot
@@ -403,14 +407,18 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx." # mnemonic, traits> {}
+class AVX_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
// Operation that may be part of the input dialect, but whose
// form is somewhere between the user view of the operation
// and the actual lower level intrinsic in LLVM IR.
-class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
+class AVX_LowOp<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx.intr." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
//----------------------------------------------------------------------------//
// AVX Rsqrt
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index e1a3fc1c83fa3..9b367e0f84cb8 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
// Compute an approximate result.
Value yApprox = handleMultidimensionalVectors(
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
- return x86::RsqrtOp::create(builder, operands);
+ return x86::avx::RsqrtOp::create(builder, operands);
});
// Do a single step of Newton-Raphson iteration to improve the approximation.
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index 47ee5d272a890..b186652aaa866 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -46,7 +46,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
}
-LogicalResult x86::MaskCompressOp::verify() {
+LogicalResult x86::avx512::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src");
@@ -60,7 +60,7 @@ LogicalResult x86::MaskCompressOp::verify() {
return success();
}
-SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
+SmallVector<Value> x86::avx512::MaskCompressOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
@@ -82,9 +82,9 @@ SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
}
SmallVector<Value>
-x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
- const LLVMTypeConverter &typeConverter,
- RewriterBase &rewriter) {
+x86::avx::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements.
Value scale =
@@ -94,7 +94,7 @@ x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
return intrinsicOperands;
}
-SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::BcstToPackedF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -102,7 +102,7 @@ SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
-SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -110,7 +110,7 @@ SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
-SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
diff --git a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
index 6cd58da8024c2..92a0209b91c1f 100644
--- a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
@@ -24,10 +24,10 @@ namespace {
// Validates whether the given operation is an x86 operation and has only
// one consumer.
static bool validateFMAOperands(Value op) {
- if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
+ if (auto cvt = op.getDefiningOp<x86::avx::CvtPackedEvenIndexedToF32Op>())
return cvt.getResult().hasOneUse();
- if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
+ if (auto bcst = op.getDefiningOp<x86::avx::BcstToPackedF32Op>())
return bcst.getResult().hasOneUse();
return false;
@@ -42,8 +42,8 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
Value lhs = fmaOp.getLhs();
Value rhs = fmaOp.getRhs();
- if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
- !isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
+ if (!isa<x86::avx::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
+ !isa<x86::avx::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
return false;
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
@@ -150,9 +150,10 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
if (!fma)
continue;
- bool hasX86CvtOperand =
- isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
- isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
+ bool hasX86CvtOperand = isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+ fma.getLhs().getDefiningOp()) ||
+ isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+ fma.getRhs().getDefiningOp());
if (hasX86CvtOperand && stopAtNextDependentFMA)
break;
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
index 598f3462a7e2a..287892c3a660f 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
@@ -446,10 +446,11 @@ struct VectorContractBF16ToFMA
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
- auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstBF16ElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
- auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
- rewriter, loc, dstType, nonUnitDimSubview[0]);
+ auto loadEvenIdxElementF32 =
+ x86::avx::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+ nonUnitDimSubview[0]);
auto evenIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
loadEvenIdxElementF32, castAcc);
@@ -467,7 +468,7 @@ struct VectorContractBF16ToFMA
accTyPairCont.getElementType()),
pairContractOp.getAcc());
- auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+ auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA = vector::FMAOp::create(
rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
@@ -480,18 +481,18 @@ struct VectorContractBF16ToFMA
}
// Load, broadcast, and do FMA for odd indexed BF16 elements.
- auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstOddIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
- auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+ auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
loadOddIdxElementF32, castAcc);
// Load, broadcast, and do FMA for even indexed BF16 elements.
- auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstEvenIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]);
- auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
+ auto loadEvenIdxElementF32 = x86::avx::CvtPackedEvenIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
vector::FMAOp fma =
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index b47eede2a9156..cdf0c3925d6a3 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
if (lhsTy.getElementType().isBF16()) {
- dp = x86::DotBF16Op::create(
+ dp = x86::avx512::DotBF16Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
bitcastUnitDimPkType, castNonUnitDim);
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isSignlessInteger(8)) {
if (nonUnitDimAcc.front() == 16) {
- dp = x86::AVX10DotInt8Op::create(
+ dp = x86::avx10::AVX10DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
} else {
- dp = x86::DotInt8Op::create(
+ dp = x86::avx::DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
More information about the Mlir-commits
mailing list