[Mlir-commits] [mlir] 20fe74e - [mlir][x86] Hardware extension namespaces (#184392)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 07:02:38 PDT 2026


Author: Adam Siemieniuk
Date: 2026-03-23T15:02:34+01:00
New Revision: 20fe74ebe19c844079f7093eeb945994274737a1

URL: https://github.com/llvm/llvm-project/commit/20fe74ebe19c844079f7093eeb945994274737a1
DIFF: https://github.com/llvm/llvm-project/commit/20fe74ebe19c844079f7093eeb945994274737a1.diff

LOG: [mlir][x86] Hardware extension namespaces (#184392)

Adds hardware extension C++ namespaces to X86 dialect op definitions to
match the extension prefixes of their IR mnemonics.

All X86 dialect ops are updated to follow the scheme first introduced
with AMX ops, i.e., 'x86::{ext}::{op_name}'.
Nested namespaces improve source code readability by explicitly
indicating which hardware extension each operation requires, and they
align the naming scheme between code and IR.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/X86/X86.td
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/lib/Dialect/X86/IR/X86Dialect.cpp
    mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
    mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
    mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index e8965d04c2145..814bf884395bb 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -35,8 +35,10 @@ def X86_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 // Operation that is part of the input dialect.
-class AVX512_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86_Dialect, "avx512." # mnemonic, traits> {}
+class AVX512_Op<string mnemonic, list<Trait> traits = []>
+    : Op<X86_Dialect, "avx512." # mnemonic, traits> {
+  let cppNamespace = X86_Dialect.cppNamespace # "::avx512";
+}
 
 //----------------------------------------------------------------------------//
 // MaskCompressOp
@@ -351,8 +353,10 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
 //===----------------------------------------------------------------------===//
 
 // Operation that is part of the input dialect.
-class AVX10_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86_Dialect, "avx10." # mnemonic, traits> {}
+class AVX10_Op<string mnemonic, list<Trait> traits = []>
+    : Op<X86_Dialect, "avx10." # mnemonic, traits> {
+  let cppNamespace = X86_Dialect.cppNamespace # "::avx10";
+}
 
 //----------------------------------------------------------------------------//
 // AVX10 Int8 Dot
@@ -403,14 +407,18 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
 //===----------------------------------------------------------------------===//
 
 // Operation that is part of the input dialect.
-class AVX_Op<string mnemonic, list<Trait> traits = []> :
-  Op<X86_Dialect, "avx." # mnemonic, traits> {}
+class AVX_Op<string mnemonic, list<Trait> traits = []>
+  : Op<X86_Dialect, "avx." # mnemonic, traits> {
+  let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
 
 // Operation that may be part of the input dialect, but whose
 // form is somewhere between the user view of the operation
 // and the actual lower level intrinsic in LLVM IR.
-class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
-  Op<X86_Dialect, "avx.intr." # mnemonic, traits> {}
+class AVX_LowOp<string mnemonic, list<Trait> traits = []>
+  : Op<X86_Dialect, "avx.intr." # mnemonic, traits> {
+  let cppNamespace = X86_Dialect.cppNamespace # "::avx";
+}
 
 //----------------------------------------------------------------------------//
 // AVX Rsqrt

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index e1a3fc1c83fa3..9b367e0f84cb8 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1740,7 +1740,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
   // Compute an approximate result.
   Value yApprox = handleMultidimensionalVectors(
       builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
-        return x86::RsqrtOp::create(builder, operands);
+        return x86::avx::RsqrtOp::create(builder, operands);
       });
 
   // Do a single step of Newton-Raphson iteration to improve the approximation.

diff  --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
index 47ee5d272a890..b186652aaa866 100644
--- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
+++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp
@@ -46,7 +46,7 @@ static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
   return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
 }
 
-LogicalResult x86::MaskCompressOp::verify() {
+LogicalResult x86::avx512::MaskCompressOp::verify() {
   if (getSrc() && getConstantSrc())
     return emitError("cannot use both src and constant_src");
 
@@ -60,7 +60,7 @@ LogicalResult x86::MaskCompressOp::verify() {
   return success();
 }
 
-SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
+SmallVector<Value> x86::avx512::MaskCompressOp::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
   auto loc = getLoc();
@@ -82,9 +82,9 @@ SmallVector<Value> x86::MaskCompressOp::getIntrinsicOperands(
 }
 
 SmallVector<Value>
-x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
-                                 const LLVMTypeConverter &typeConverter,
-                                 RewriterBase &rewriter) {
+x86::avx::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                      const LLVMTypeConverter &typeConverter,
+                                      RewriterBase &rewriter) {
   SmallVector<Value> intrinsicOperands(operands);
   // Dot product of all elements, broadcasted to all elements.
   Value scale =
@@ -94,7 +94,7 @@ x86::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return intrinsicOperands;
 }
 
-SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::BcstToPackedF32Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
   Adaptor adaptor(operands, *this);
@@ -102,7 +102,7 @@ SmallVector<Value> x86::BcstToPackedF32Op::getIntrinsicOperands(
                            typeConverter, rewriter)};
 }
 
-SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
   Adaptor adaptor(operands, *this);
@@ -110,7 +110,7 @@ SmallVector<Value> x86::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
                            typeConverter, rewriter)};
 }
 
-SmallVector<Value> x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
+SmallVector<Value> x86::avx::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
   Adaptor adaptor(operands, *this);

diff  --git a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
index 6cd58da8024c2..92a0209b91c1f 100644
--- a/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/ShuffleVectorFMAOps.cpp
@@ -24,10 +24,10 @@ namespace {
 // Validates whether the given operation is an x86 operation and has only
 // one consumer.
 static bool validateFMAOperands(Value op) {
-  if (auto cvt = op.getDefiningOp<x86::CvtPackedEvenIndexedToF32Op>())
+  if (auto cvt = op.getDefiningOp<x86::avx::CvtPackedEvenIndexedToF32Op>())
     return cvt.getResult().hasOneUse();
 
-  if (auto bcst = op.getDefiningOp<x86::BcstToPackedF32Op>())
+  if (auto bcst = op.getDefiningOp<x86::avx::BcstToPackedF32Op>())
     return bcst.getResult().hasOneUse();
 
   return false;
@@ -42,8 +42,8 @@ static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
   Value lhs = fmaOp.getLhs();
   Value rhs = fmaOp.getRhs();
 
-  if (!isa<x86::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
-      !isa<x86::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
+  if (!isa<x86::avx::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
+      !isa<x86::avx::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
     return false;
 
   if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
@@ -150,9 +150,10 @@ struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
       if (!fma)
         continue;
 
-      bool hasX86CvtOperand =
-          isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getLhs().getDefiningOp()) ||
-          isa<x86::CvtPackedEvenIndexedToF32Op>(fma.getRhs().getDefiningOp());
+      bool hasX86CvtOperand = isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+                                  fma.getLhs().getDefiningOp()) ||
+                              isa<x86::avx::CvtPackedEvenIndexedToF32Op>(
+                                  fma.getRhs().getDefiningOp());
 
       if (hasX86CvtOperand && stopAtNextDependentFMA)
         break;

diff  --git a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
index 598f3462a7e2a..287892c3a660f 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractBF16ToFMA.cpp
@@ -446,10 +446,11 @@ struct VectorContractBF16ToFMA
           VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
           contractOp.getAcc());
 
-      auto loadBcstBF16ElementToF32 = x86::BcstToPackedF32Op::create(
+      auto loadBcstBF16ElementToF32 = x86::avx::BcstToPackedF32Op::create(
           rewriter, loc, dstType, unitDimSubview[0]);
-      auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
-          rewriter, loc, dstType, nonUnitDimSubview[0]);
+      auto loadEvenIdxElementF32 =
+          x86::avx::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
+                                                        nonUnitDimSubview[0]);
       auto evenIdxFMA =
           vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
                                 loadEvenIdxElementF32, castAcc);
@@ -467,7 +468,7 @@ struct VectorContractBF16ToFMA
                           accTyPairCont.getElementType()),
           pairContractOp.getAcc());
 
-      auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+      auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
           rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
       auto oddIdxFMA = vector::FMAOp::create(
           rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
@@ -480,18 +481,18 @@ struct VectorContractBF16ToFMA
     }
 
     // Load, broadcast, and do FMA for odd indexed BF16 elements.
-    auto loadBcstOddIdxElementToF32 = x86::BcstToPackedF32Op::create(
+    auto loadBcstOddIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[0]);
-    auto loadOddIdxElementF32 = x86::CvtPackedOddIndexedToF32Op::create(
+    auto loadOddIdxElementF32 = x86::avx::CvtPackedOddIndexedToF32Op::create(
         rewriter, loc, dstType, nonUnitDimSubview[0]);
     auto oddIdxFMA =
         vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
                               loadOddIdxElementF32, castAcc);
 
     // Load, broadcast, and do FMA for even indexed BF16 elements.
-    auto loadBcstEvenIdxElementToF32 = x86::BcstToPackedF32Op::create(
+    auto loadBcstEvenIdxElementToF32 = x86::avx::BcstToPackedF32Op::create(
         rewriter, loc, dstType, unitDimSubview[1]);
-    auto loadEvenIdxElementF32 = x86::CvtPackedEvenIndexedToF32Op::create(
+    auto loadEvenIdxElementF32 = x86::avx::CvtPackedEvenIndexedToF32Op::create(
         rewriter, loc, dstType, nonUnitDimSubview[0]);
     vector::FMAOp fma =
         vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,

diff  --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index b47eede2a9156..cdf0c3925d6a3 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -384,7 +384,7 @@ struct VectorContractToPackedTypeDotProduct
         rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
 
     if (lhsTy.getElementType().isBF16()) {
-      dp = x86::DotBF16Op::create(
+      dp = x86::avx512::DotBF16Op::create(
           rewriter, loc,
           VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
           bitcastUnitDimPkType, castNonUnitDim);
@@ -392,12 +392,12 @@ struct VectorContractToPackedTypeDotProduct
 
     if (lhsTy.getElementType().isSignlessInteger(8)) {
       if (nonUnitDimAcc.front() == 16) {
-        dp = x86::AVX10DotInt8Op::create(
+        dp = x86::avx10::AVX10DotInt8Op::create(
             rewriter, loc,
             VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
             castAcc, bitcastUnitDimPkType, castNonUnitDim);
       } else {
-        dp = x86::DotInt8Op::create(
+        dp = x86::avx::DotInt8Op::create(
             rewriter, loc,
             VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
             castAcc, bitcastUnitDimPkType, castNonUnitDim);


        


More information about the Mlir-commits mailing list