[Mlir-commits] [mlir] [mlir][x86vector] AVX10 I8 Dot Op (PR #178807)

Arun Thangamani llvmlistbot at llvm.org
Sun Feb 1 15:42:12 PST 2026


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/178807

>From bc56e3e5eb5ac2c17e7e1e8e5e3c9e1cf86814b0 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 17:14:40 -0800
Subject: [PATCH 1/8] Add x86vector avx512 int8 dot op mapped to LLVM avx512 vpdpwssd

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 52 +++++++++++++++++++
 .../Dialect/X86Vector/legalize-for-llvm.mlir  | 24 +++++++++
 mlir/test/Dialect/X86Vector/roundtrip.mlir    | 24 +++++++++
 mlir/test/Target/LLVMIR/x86vector.mlir        | 24 +++++++++
 4 files changed, 124 insertions(+)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 468242d1c2780..f7efb5316546f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,6 +343,58 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
   }];
 }
 
+//----------------------------------------------------------------------------//
+// AVX512 Int8 Dot
+//----------------------------------------------------------------------------//
+
+def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["a", "b"]>,
+    AllTypesMatch<["w", "dst"]>,
+    TypesMatchWith<"`a` has four times as many elements as `w`",
+                   "w", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+                   "IntegerType::get($_self.getContext(), 8))">
+  ]> {
+  let summary = "AVX512 Dot Int8 op";
+  let description = [{
+    The `dot` op is an AVX512-Int8 specific op that can lower to the proper
+    LLVMAVX2-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+    vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
+    results. Sum these 4 results with the corresponding 32-bit integer in `w`, and
+    store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [I32]>:$w,
+                   VectorOfLengthAndType<[16, 32, 64], [I8]>:$a,
+                   VectorOfLengthAndType<[16, 32, 64], [I8]>:$b
+                   );
+  let results = (outs VectorOfLengthAndType<[4, 8, 16], [I32]>:$dst);
+  let assemblyFormat =
+    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+  let extraClassDeclaration = [{
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.avx512.vpdpwssd";
+      VectorType vecType = getW().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += "." + std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // AVX op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 72dc899f4f0a6..89606d2681bf8 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,30 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// CHECK-LABEL: func @avx512_dot_i8_128
+func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.128"
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_256
+func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.256"
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_512
+func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.512"
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 959177b27c7ea..4bd88d473a614 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,30 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// CHECK-LABEL: func @avx512_dot_i8_128
+func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_256
+func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_512
+func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
 func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
   %a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 74ae2424964b1..b5277b5bb96bb 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,6 +109,30 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx512_vpdpwssd_128
+func.func @LLVM_x86_avx512_vpdpwssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx512_vpdpwssd_256
+func.func @LLVM_x86_avx512_vpdpwssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx512_vpdpwssd_512
+func.func @LLVM_x86_avx512_vpdpwssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+    %b: vector<64xi8>) -> vector<16xi32> {
+  // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.512(
+  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  return %0 : vector<16xi32>
+}
+
 // CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
 func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
   %a: memref<8xbf16>) -> vector<4xf32>

>From 40a2f4d477e1e1f80282acc41bf04efb3ebdfa90 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 18:42:56 -0800
Subject: [PATCH 2/8] Correct the target lowering to AVX10 vpdpbssd instead of AVX512

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 32 +++++++++++--------
 .../Dialect/X86Vector/legalize-for-llvm.mlir  | 24 +++-----------
 mlir/test/Dialect/X86Vector/roundtrip.mlir    | 24 +++-----------
 mlir/test/Target/LLVMIR/x86vector.mlir        | 24 +++-----------
 4 files changed, 30 insertions(+), 74 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index f7efb5316546f..299fec9410b39 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,11 +343,19 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AVX10 op definitions
+//===----------------------------------------------------------------------===//
+
+// Operation that is part of the input dialect.
+class AVX10_Op<string mnemonic, list<Trait> traits = []> :
+  Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
+
 //----------------------------------------------------------------------------//
-// AVX512 Int8 Dot
+// AVX10 Int8 Dot
 //----------------------------------------------------------------------------//
 
-def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
+def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
     X86IntrinsicOpInterface,
     AllTypesMatch<["a", "b"]>,
     AllTypesMatch<["w", "dst"]>,
@@ -356,10 +364,10 @@ def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
                    "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
                    "IntegerType::get($_self.getContext(), 8))">
   ]> {
-  let summary = "AVX512 Dot Int8 op";
+  let summary = "AVX10 Dot Int8 op";
   let description = [{
     The `dot` op is an AVX512-Int8 specific op that can lower to the proper
-    LLVMAVX2-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+    LLVMAVX10-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
     vectors it is applied to.
 
     #### From the Intel Intrinsics Guide:
@@ -371,24 +379,20 @@ def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
 
     Example:
     ```mlir
-    %dst = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+    %dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
     ```
   }];
-  let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [I32]>:$w,
-                   VectorOfLengthAndType<[16, 32, 64], [I8]>:$a,
-                   VectorOfLengthAndType<[16, 32, 64], [I8]>:$b
+  let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
+                   VectorOfLengthAndType<[64], [I8]>:$a,
+                   VectorOfLengthAndType<[64], [I8]>:$b
                    );
-  let results = (outs VectorOfLengthAndType<[4, 8, 16], [I32]>:$dst);
+  let results = (outs VectorOfLengthAndType<[16], [I32]>:$dst);
   let assemblyFormat =
     "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
 
   let extraClassDeclaration = [{
     std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.avx512.vpdpwssd";
-      VectorType vecType = getW().getType();
-      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
-      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
-      intr += "." + std::to_string(opBitWidth);
+      std::string intr = "llvm.x86.avx10.vpdpbssd.512";
       return intr;
     }
   }];
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 89606d2681bf8..6868b55095461 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,27 +95,11 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
-// CHECK-LABEL: func @avx512_dot_i8_128
-func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
-    %b: vector<16xi8>) -> vector<4xi32> {
-  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.128"
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
-  return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_256
-func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
-    %b: vector<32xi8>) -> vector<8xi32> {
-  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.256"
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_512
-func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
     %b: vector<64xi8>) -> vector<16xi32> {
-  // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.512"
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
   return %0 : vector<16xi32>
 }
 
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 4bd88d473a614..672c32c9c3cc3 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,27 +94,11 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
-// CHECK-LABEL: func @avx512_dot_i8_128
-func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
-    %b: vector<16xi8>) -> vector<4xi32> {
-  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
-  return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_256
-func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
-    %b: vector<32xi8>) -> vector<8xi32> {
-  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_512
-func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
     %b: vector<64xi8>) -> vector<16xi32> {
-  // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  // CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
   return %0 : vector<16xi32>
 }
 
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index b5277b5bb96bb..43d7f6a13daaa 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,27 +109,11 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
   return %0 : vector<16xbf16>
 }
 
-// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx512_vpdpwssd_128
-func.func @LLVM_x86_avx512_vpdpwssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
-    %b: vector<16xi8>) -> vector<4xi32> {
-  // CHECK: call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
-  return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx512_vpdpwssd_256
-func.func @LLVM_x86_avx512_vpdpwssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
-    %b: vector<32xi8>) -> vector<8xi32> {
-  // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
-  return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx512_vpdpwssd_512
-func.func @LLVM_x86_avx512_vpdpwssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
+func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
     %b: vector<64xi8>) -> vector<16xi32> {
-  // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.512(
-  %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+  // CHECK: call <8 x i32> @llvm.x86.avx10.vpdpbssd.512(
+  %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
   return %0 : vector<16xi32>
 }
 

>From d241eec5f00f70db789d74aed92409e142f87a3b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 19:13:44 -0800
Subject: [PATCH 3/8] Fix the vpdpbssd.512 result type in the LLVM IR CHECK line

---
 mlir/test/Target/LLVMIR/x86vector.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 43d7f6a13daaa..aad4a60720328 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -112,7 +112,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
 // CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
 func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
     %b: vector<64xi8>) -> vector<16xi32> {
-  // CHECK: call <8 x i32> @llvm.x86.avx10.vpdpbssd.512(
+  // CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
   %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
   return %0 : vector<16xi32>
 }

>From 3ccd1810759996c3ebe2f365f1cab5094525f32e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 19:17:42 -0800
Subject: [PATCH 4/8] Fix the intrinsic name (vpdpwssd -> vpdpbssd) in the op description

---
 mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 299fec9410b39..7a47f2876472a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -367,7 +367,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
   let summary = "AVX10 Dot Int8 op";
   let description = [{
     The `dot` op is an AVX512-Int8 specific op that can lower to the proper
-    LLVMAVX10-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+    LLVMAVX10-INT8 operation `llvm.vpdpbssd` depending on the width of MLIR
     vectors it is applied to.
 
     #### From the Intel Intrinsics Guide:

>From 5e413df209d816984d5c3d9bceaf83ad43149acb Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 30 Jan 2026 16:55:09 -0800
Subject: [PATCH 5/8] Resolve a few description typos and return the intrinsic name directly

---
 mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 7a47f2876472a..4e2761e283150 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -366,11 +366,8 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
   ]> {
   let summary = "AVX10 Dot Int8 op";
   let description = [{
-    The `dot` op is an AVX512-Int8 specific op that can lower to the proper
-    LLVMAVX10-INT8 operation `llvm.vpdpbssd` depending on the width of MLIR
-    vectors it is applied to.
-
-    #### From the Intel Intrinsics Guide:
+    The `dot` op is an AVX10-Int8 specific op that can lower to the proper
+    LLVMAVX10-INT8 operation `llvm.vpdpbssdi.512`.
 
     Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
     corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
@@ -392,8 +389,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
 
   let extraClassDeclaration = [{
     std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.avx10.vpdpbssd.512";
-      return intr;
+      return "llvm.x86.avx10.vpdpbssd.512";
     }
   }];
 }

>From 4d2cac750d3c371e5ee96ff2b53fe35bf935063b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 30 Jan 2026 16:56:12 -0800
Subject: [PATCH 6/8] Fix the intrinsic name typo (vpdpbssdi) in the op description

---
 mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4e2761e283150..cd16a36a6ee0b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -367,7 +367,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
   let summary = "AVX10 Dot Int8 op";
   let description = [{
-    The `dot` op is an AVX10-Int8 specific op that can lower to the proper
-    LLVMAVX10-INT8 operation `llvm.vpdpbssdi.512`.
+    The `dot.i8` op is an AVX10 INT8 specific op that lowers to the LLVM
+    intrinsic `llvm.x86.avx10.vpdpbssd.512` (512-bit vectors only).
 
     Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
     corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit

>From 21eeb127047a3165aca9f65132b09ac966fc1f3a Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 1 Feb 2026 15:41:05 -0800
Subject: [PATCH 7/8] Support AVX10 int8 dot-product nano-kernel lowering

---
 .../VectorContractToPackedTypeDotProduct.cpp  | 39 +++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a00a3e5bdd766..89aa53307b95d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -113,10 +113,13 @@ struct VectorContractToPackedTypeDotProduct
                       "RHS) dim and acc dim of size 4/8/16.");
 
-    if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
-        nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+    // Int8 dot products support 4/8/16 i32 lanes; reject any other width, and
+    // reject an accumulator whose non-unit dim differs from nonUnitDim.
+    if (lhsTy.getElementType().isSignlessInteger(8) &&
+        ((nonUnitDim != 4 && nonUnitDim != 8 && nonUnitDim != 16) ||
+         nonUnitDimAcc.front() != nonUnitDim))
       return rewriter.notifyMatchFailure(
-          contractOp, "Int8 dot-product operation expects non-unit (LHR or "
-                      "RHS) dim and acc dim of size 4/8.");
+          contractOp, "Int8 dot-product operation expects non-unit (LHS or "
+                      "RHS) dim and acc dim of size 4/8/16.");
 
     auto loc = contractOp.getLoc();
     auto castAcc = vector::ShapeCastOp::create(
@@ -159,10 +162,16 @@ struct VectorContractToPackedTypeDotProduct
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
-            castAcc, bitcastLhsPkType, castRhs);
+        auto dpResTy =
+            VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32));
+        // 16-lane (512-bit) accumulators lower to the AVX10 op; 4/8-lane
+        // accumulators keep using x86vector::DotInt8Op.
+        if (nonUnitDimAcc.front() == 16)
+          dp = x86vector::AVX10DotInt8Op::create(
+              rewriter, loc, dpResTy, castAcc, bitcastLhsPkType, castRhs);
+        else
+          dp = x86vector::DotInt8Op::create(rewriter, loc, dpResTy, castAcc,
+                                            bitcastLhsPkType, castRhs);
       }
     } else {
       auto castLhs = vector::ShapeCastOp::create(
@@ -192,10 +201,16 @@ struct VectorContractToPackedTypeDotProduct
       }
 
       if (lhsTy.getElementType().isSignlessInteger(8)) {
-        dp = x86vector::DotInt8Op::create(
-            rewriter, loc,
-            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
-            castAcc, castLhs, bitcastRhsPkType);
+        auto dpResTy =
+            VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32));
+        // 16-lane (512-bit) accumulators lower to the AVX10 op; 4/8-lane
+        // accumulators keep using x86vector::DotInt8Op.
+        if (nonUnitDimAcc.front() == 16)
+          dp = x86vector::AVX10DotInt8Op::create(
+              rewriter, loc, dpResTy, castAcc, castLhs, bitcastRhsPkType);
+        else
+          dp = x86vector::DotInt8Op::create(rewriter, loc, dpResTy, castAcc,
+                                            castLhs, bitcastRhsPkType);
       }
     }
 

>From d8363d5c5e6e4e202c6561cdc56257a717148531 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 1 Feb 2026 15:41:33 -0800
Subject: [PATCH 8/8] Add tests for AVX10 int8 dot-product nano-kernel lowering

---
 ...or-contract-to-packed-type-dotproduct.mlir | 74 ++++++++++++++++++-
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 65676cbae772c..a2231fc97993d 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -68,6 +68,75 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_avx10int8dp(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_avx10int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x16x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x16x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_avx10int8dp_bcst_B(
+  %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+  %0 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2
+    : !vecA, !vecB into !vecC
+  return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x1x4xi8>
 !vecB = vector<1x1x8x4xi8>
 !vecC = vector<1x8xi32>
@@ -615,8 +684,8 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 !vecA = vector<1x1x1x4xi8>
-!vecB = vector<1x1x16x4xi8>
-!vecC = vector<1x1x16xi32>
+!vecB = vector<1x1x32x4xi8>
+!vecC = vector<1x1x32xi32>
 #map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
 #map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
 #map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
@@ -634,6 +703,7 @@ func.func @negative_wrong_vector_shape_int8(
 
 // CHECK-LABEL: @negative_wrong_vector_shape_int8
 // CHECK-NOT: x86vector.avx.dot.i8
+// CHECK-NOT: x86vector.avx10.dot.i8
 // CHECK: vector.contract
 
 module attributes {transform.with_named_sequence} {



More information about the Mlir-commits mailing list