[Mlir-commits] [mlir] 8e18cdc - [mlir][x86vector] AVX10 I8 Dot Op (#178807)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 2 03:19:06 PST 2026


Author: Arun Thangamani
Date: 2026-02-02T16:49:00+05:30
New Revision: 8e18cdcb151eacfc23b28f9fa70458f57430a178

URL: https://github.com/llvm/llvm-project/commit/8e18cdcb151eacfc23b28f9fa70458f57430a178
DIFF: https://github.com/llvm/llvm-project/commit/8e18cdcb151eacfc23b28f9fa70458f57430a178.diff

LOG: [mlir][x86vector] AVX10 I8 Dot Op (#178807)

Adds AVX10 i8 dot-product operation and defines lowering to LLVM
intrinsics.

Target LLVM intrinsic: `llvm.x86.avx10.vpdpbssd.512`

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir
    mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
    mlir/test/Target/LLVMIR/x86vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 468242d1c2780..cd16a36a6ee0b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,6 +343,58 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AVX10 op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for AVX10 ops of the input dialect (mnemonic prefix `avx10.`).
+class AVX10_Op<string mnemonic, list<Trait> traits = []> :
+  Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
+
+//----------------------------------------------------------------------------//
+// AVX10 Int8 Dot
+//----------------------------------------------------------------------------//
+
+def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["a", "b"]>,
+    AllTypesMatch<["w", "dst"]>,
+    TypesMatchWith<"`a` has four times as many elements as `w`",
+                   "w", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+                   "IntegerType::get($_self.getContext(), 8))">
+  ]> {
+  let summary = "AVX10 Dot Int8 op";
+  let description = [{
+    The `avx10.dot.i8` op is an AVX10-INT8 specific op that lowers to the
+    LLVM AVX10-INT8 intrinsic `llvm.x86.avx10.vpdpbssd.512`.
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate signed
+    16-bit results. Sum these 4 results with the corresponding 32-bit integer in
+    `w`, and store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
+                   VectorOfLengthAndType<[64], [I8]>:$a,
+                   VectorOfLengthAndType<[64], [I8]>:$b
+                   );
+  let results = (outs VectorOfLengthAndType<[16], [I32]>:$dst);
+  let assemblyFormat =
+    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+  let extraClassDeclaration = [{
+    std::string getIntrinsicName() {
+      return "llvm.x86.avx10.vpdpbssd.512";
+    }
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a00a3e5bdd766..89aa53307b95d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -113,10 +113,11 @@ struct VectorContractToPackedTypeDotProduct
                       "RHS) dim and acc dim of size 4/8/16.");
 
     if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
-        nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+        nonUnitDim != 8 && nonUnitDim != 16 &&
+        nonUnitDimAcc.front() == nonUnitDim)
       return rewriter.notifyMatchFailure(
           contractOp, "Int8 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8.");
+                      "RHS) dim and acc dim of size 4/8/16.");
 
     auto loc = contractOp.getLoc();
     auto castAcc = vector::ShapeCastOp::create(
@@ -159,10 +160,19 @@ struct VectorContractToPackedTypeDotProduct
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
-            castAcc, bitcastLhsPkType, castRhs);
+        if (nonUnitDimAcc.front() == 16) {
+          dp = x86vector::AVX10DotInt8Op::create(
+              rewriter, loc,
+              VectorType::get(nonUnitDimRhs.front(),
+                              rewriter.getIntegerType(32)),
+              castAcc, bitcastLhsPkType, castRhs);
+        } else {
+          dp = x86vector::DotInt8Op::create(
+              rewriter, loc,
+              VectorType::get(nonUnitDimRhs.front(),
+                              rewriter.getIntegerType(32)),
+              castAcc, bitcastLhsPkType, castRhs);
+        }
       }
     } else {
       auto castLhs = vector::ShapeCastOp::create(
@@ -192,10 +202,19 @@ struct VectorContractToPackedTypeDotProduct
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
-            castAcc, castLhs, bitcastRhsPkType);
+        if (nonUnitDimAcc.front() == 16) {
+          dp = x86vector::AVX10DotInt8Op::create(
+              rewriter, loc,
+              VectorType::get(nonUnitDimLhs.front(),
+                              rewriter.getIntegerType(32)),
+              castAcc, castLhs, bitcastRhsPkType);
+        } else {
+          dp = x86vector::DotInt8Op::create(
+              rewriter, loc,
+              VectorType::get(nonUnitDimLhs.front(),
+                              rewriter.getIntegerType(32)),
+              castAcc, castLhs, bitcastRhsPkType);
+        }
       }
     }
 

diff  --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 72dc899f4f0a6..6868b55095461 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,14 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// The AVX10 i8 dot op legalizes to an llvm.call_intrinsic of vpdpbssd.512.
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>

diff  --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 959177b27c7ea..672c32c9c3cc3 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,14 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// Round-trip check for the custom assembly format of x86vector.avx10.dot.i8.
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>

diff  --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 65676cbae772c..e26a575e2bc90 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -68,6 +68,75 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// i8 contraction with a 16-wide accumulator and a unit-row A operand; expected
+// to rewrite to vector.broadcast + x86vector.avx10.dot.i8.
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_avx10int8dp(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_avx10int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Batch-matmul form with a 16-row A and a unit B operand (B is broadcast);
+// expected to rewrite to vector.broadcast + x86vector.avx10.dot.i8.
+!vecA = vector<1x16x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x16x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_avx10int8dp_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x4xi8>
 !vecB = vector<1x1x8x4xi8>
 !vecC = vector<1x8xi32>
@@ -615,8 +684,8 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 !vecA = vector<1x1x1x4xi8>
-!vecB = vector<1x1x16x4xi8>
-!vecC = vector<1x1x16xi32>
+!vecB = vector<1x1x32x4xi8>
+!vecC = vector<1x1x32xi32>
 #map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
 #map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
 #map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
@@ -634,6 +703,7 @@ func.func @negative_wrong_vector_shape_int8(
 
 // CHECK-LABEL: @negative_wrong_vector_shape_int8
 // CHECK-NOT: x86vector.avx.dot.i8
+// CHECK-NOT: x86vector.avx10.dot.i8
 // CHECK: vector.contract
 
 module attributes {transform.with_named_sequence} {

diff  --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 74ae2424964b1..aad4a60720328 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,6 +109,14 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// The AVX10 i8 dot op translates to a call of @llvm.x86.avx10.vpdpbssd.512.
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
+func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
 func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
   %a: memref<8xbf16>) -> vector<4xf32>


        


More information about the Mlir-commits mailing list