[Mlir-commits] [mlir] 20fe74e - [mlir][x86] Hardware extension namespaces (#184392)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 07:02:38 PDT 2026
Author: Adam Siemieniuk
Date: 2026-03-23T15:02:34+01:00
New Revision: 20fe74ebe19c844079f7093eeb945994274737a1
URL: https://github.com/llvm/llvm-project/commit/20fe74ebe19c844079f7093eeb945994274737a1
DIFF: https://github.com/llvm/llvm-project/commit/20fe74ebe19c844079f7093eeb945994274737a1.diff
LOG: [mlir][x86] Hardware extension namespaces (#184392)
Adds hardware extension C++ namespaces to X86 dialect op definitions to
match their IR mnemonic extensions.
All X86 dialect ops are updated to follow the scheme first introduced
with AMX ops, i.e., 'x86::{ext}::{op_name}'.
Nested namespaces improve source code readability by explicitly
indicating which hardware extension each operation requires, and they
align the naming scheme between code and IR.
Added:
Modified:
mlir/include/mlir/Dialect/X86/X86.td
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
mlir/lib/Dialect/X86/IR/X86Dialect.cpp
mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index e8965d04c2145..814bf884395bb 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -35,8 +35,10 @@ def X86_Dialect : Dialect {
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX512_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx512." # mnemonic, traits> {}
+class AVX512_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx512." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx512";
+}
//----------------------------------------------------------------------------//
// MaskCompressOp
@@ -351,8 +353,10 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX10_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx10." # mnemonic, traits> {}
+class AVX10_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx10." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx10";
+}
//----------------------------------------------------------------------------//
// AVX10 Int8 Dot
@@ -403,14 +407,18 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
//===----------------------------------------------------------------------===//
// Operation that is part of the input dialect.
-class AVX_Op<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx." # mnemonic, traits> {}
+class AVX_Op<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
// Operation that may be part of the input dialect, but whose
// form is somewhere between the user view of the operation
// and the actual lower level intrinsic in LLVM IR.
-class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
- Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
+class AVX_LowOp<string mnemonic, list<Trait> traits = []>
+ : Op<X86_Dialect, "avx.intr." # mnemonic, traits> {
+ let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
//----------------------------------------------------------------------------//
// AVX Rsqrt
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index e1a3fc1c83fa3..9b367e0f84cb8 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
// Compute an approximate result.
Value yApprox = handleMultidimensionalVectors(
builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
- return x86::RsqrtOp::create(builder, operands);
+ return x86::avx::RsqrtOp::create(builder, operands);
});
// Do a single step of Newton-Raphson iteration to improve the approximation.
diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index 47ee5d272a890..b186652aaa866 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -46,7 +46,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
}
-LogicalResult x86::MaskCompressOp::verify() {
+LogicalResult x86::avx512::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src");
@@ -60,7 +60,7 @@ LogicalResult x86::MaskCompressOp::verify() {
return success();
}
-SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
+SmallVector<Value> x86::avx512::MaskCompressOp::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
auto loc = getLoc();
@@ -82,9 +82,9 @@ SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
}
SmallVector<Value>
-x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
- const LLVMTypeConverter &typeConverter,
- RewriterBase &rewriter) {
+x86::avx::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+ const LLVMTypeConverter &typeConverter,
+ RewriterBase &rewriter) {
SmallVector<Value> intrinsicOperands(operands);
// Dot product of all elements, broadcasted to all elements.
Value scale =
@@ -94,7 +94,7 @@ x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
return intrinsicOperands;
}
-SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::BcstToPackedF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -102,7 +102,7 @@ SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
-SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
@@ -110,7 +110,7 @@ SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
typeConverter, rewriter)};
}
-SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
Adaptor adaptor(operands, *this);
diff --git a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
index 6cd58da8024c2..92a0209b91c1f 100644
--- a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
@@ -24,10 +24,10 @@ namespace {
// Validates whether the given operation is an x86 operation and has only
// one consumer.
static bool validateFMAOperands(Value op) {
- if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
+ if (auto cvt = op.getDefiningOp<x86::avx::CvtPackedEvenIndexedToF32Op>())
return cvt.getResult().hasOneUse();
- if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
+ if (auto bcst = op.getDefiningOp<x86::avx::BcstToPackedF32Op>())
return bcst.getResult().hasOneUse();
return false;
@@ -42,8 +42,8 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
Value lhs = fmaOp.getLhs();
Value rhs = fmaOp.getRhs();
- if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
- !isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
+ if (!isa<x86::avx::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
+ !isa<x86::avx::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
return false;
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
@@ -150,9 +150,10 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
if (!fma)
continue;
- bool hasX86CvtOperand =
- isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
- isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
+ bool hasX86CvtOperand = isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+ fma.getLhs().getDefiningOp()) ||
+ isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+ fma.getRhs().getDefiningOp());
if (hasX86CvtOperand && stopAtNextDependentFMA)
break;
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
index 598f3462a7e2a..287892c3a660f 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
@@ -446,10 +446,11 @@ struct VectorContractBF16ToFMA
VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
contractOp.getAcc());
- auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstBF16ElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
- auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
- rewriter, loc, dstType, nonUnitDimSubview[0]);
+ auto loadEvenIdxElementF32 =
+ x86::avx::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+ nonUnitDimSubview[0]);
auto evenIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
loadEvenIdxElementF32, castAcc);
@@ -467,7 +468,7 @@ struct VectorContractBF16ToFMA
accTyPairCont.getElementType()),
pairContractOp.getAcc());
- auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+ auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA = vector::FMAOp::create(
rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
@@ -480,18 +481,18 @@ struct VectorContractBF16ToFMA
}
// Load, broadcast, and do FMA for odd indexed BF16 elements.
- auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstOddIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[0]);
- auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+ auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
auto oddIdxFMA =
vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
loadOddIdxElementF32, castAcc);
// Load, broadcast, and do FMA for even indexed BF16 elements.
- auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
+ auto loadBcstEvenIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
rewriter, loc, dstType, unitDimSubview[1]);
- auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
+ auto loadEvenIdxElementF32 = x86::avx::CvtPackedEvenIndexedToF32Op::create(
rewriter, loc, dstType, nonUnitDimSubview[0]);
vector::FMAOp fma =
vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index b47eede2a9156..cdf0c3925d6a3 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
if (lhsTy.getElementType().isBF16()) {
- dp = x86::DotBF16Op::create(
+ dp = x86::avx512::DotBF16Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
bitcastUnitDimPkType, castNonUnitDim);
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
if (lhsTy.getElementType().isSignlessInteger(8)) {
if (nonUnitDimAcc.front() == 16) {
- dp = x86::AVX10DotInt8Op::create(
+ dp = x86::avx10::AVX10DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
} else {
- dp = x86::DotInt8Op::create(
+ dp = x86::avx::DotInt8Op::create(
rewriter, loc,
VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
castAcc, bitcastUnitDimPkType, castNonUnitDim);
More information about the Mlir-commits
mailing list