[Mlir-commits] [mlir] [mlir][x86vector] AVX10 I8 Dot Op (PR #178807)
Arun Thangamani
llvmlistbot at llvm.org
Sun Feb 1 15:42:12 PST 2026
https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/178807
>From bc56e3e5eb5ac2c17e7e1e8e5e3c9e1cf86814b0 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 17:14:40 -0800
Subject: [PATCH 1/8] MLIR dialect op map to llvm avx512 int8dp
---
.../mlir/Dialect/X86Vector/X86Vector.td | 52 +++++++++++++++++++
.../Dialect/X86Vector/legalize-for-llvm.mlir | 24 +++++++++
mlir/test/Dialect/X86Vector/roundtrip.mlir | 24 +++++++++
mlir/test/Target/LLVMIR/x86vector.mlir | 24 +++++++++
4 files changed, 124 insertions(+)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 468242d1c2780..f7efb5316546f 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,6 +343,58 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
}];
}
+//----------------------------------------------------------------------------//
+// AVX512 Int8 Dot
+//----------------------------------------------------------------------------//
+
+def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["w", "dst"]>,
+ TypesMatchWith<"`a` has four times elements as `w`",
+ "w", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+ "IntegerType::get($_self.getContext(), 8))">
+ ]> {
+ let summary = "AVX512 Dot Int8 op";
+ let description = [{
+ The `dot` op is an AVX512-Int8 specific op that can lower to the proper
+ LLVMAVX2-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+ vectors it is applied to.
+
+ #### From the Intel Intrinsics Guide:
+
+ Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+ corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
+ results. Sum these 4 results with the corresponding 32-bit integer in `w`, and
+ store the packed 32-bit results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [I32]>:$w,
+ VectorOfLengthAndType<[16, 32, 64], [I8]>:$a,
+ VectorOfLengthAndType<[16, 32, 64], [I8]>:$b
+ );
+ let results = (outs VectorOfLengthAndType<[4, 8, 16], [I32]>:$dst);
+ let assemblyFormat =
+ "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
+ std::string intr = "llvm.x86.avx512.vpdpwssd";
+ VectorType vecType = getW().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += "." + std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 72dc899f4f0a6..89606d2681bf8 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,30 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avx512_dot_i8_128
+func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.128"
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_256
+func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.256"
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_512
+func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.512"
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 959177b27c7ea..4bd88d473a614 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,30 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avx512_dot_i8_128
+func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_256
+func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func @avx512_dot_i8_512
+func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 74ae2424964b1..b5277b5bb96bb 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,6 +109,30 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx512_vpdpwssd_128
+func.func @LLVM_x86_avx512_vpdpwssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx512_vpdpwssd_256
+func.func @LLVM_x86_avx512_vpdpwssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx512_vpdpwssd_512
+func.func @LLVM_x86_avx512_vpdpwssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.512(
+ %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
>From 40a2f4d477e1e1f80282acc41bf04efb3ebdfa90 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 18:42:56 -0800
Subject: [PATCH 2/8] correcting the target lowering for avx10 not avx512
---
.../mlir/Dialect/X86Vector/X86Vector.td | 32 +++++++++++--------
.../Dialect/X86Vector/legalize-for-llvm.mlir | 24 +++-----------
mlir/test/Dialect/X86Vector/roundtrip.mlir | 24 +++-----------
mlir/test/Target/LLVMIR/x86vector.mlir | 24 +++-----------
4 files changed, 30 insertions(+), 74 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index f7efb5316546f..299fec9410b39 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,11 +343,19 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
}];
}
+//===----------------------------------------------------------------------===//
+// AVX10 op definitions
+//===----------------------------------------------------------------------===//
+
+// Operation that is part of the input dialect.
+class AVX10_Op<string mnemonic, list<Trait> traits = []> :
+ Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
+
//----------------------------------------------------------------------------//
-// AVX512 Int8 Dot
+// AVX10 Int8 Dot
//----------------------------------------------------------------------------//
-def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
+def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
X86IntrinsicOpInterface,
AllTypesMatch<["a", "b"]>,
AllTypesMatch<["w", "dst"]>,
@@ -356,10 +364,10 @@ def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
"IntegerType::get($_self.getContext(), 8))">
]> {
- let summary = "AVX512 Dot Int8 op";
+ let summary = "AVX10 Dot Int8 op";
let description = [{
The `dot` op is an AVX512-Int8 specific op that can lower to the proper
- LLVMAVX2-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+ LLVMAVX10-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
vectors it is applied to.
#### From the Intel Intrinsics Guide:
@@ -371,24 +379,20 @@ def AVX512DotInt8Op : AVX512_Op<"dot.i8", [Pure,
Example:
```mlir
- %dst = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ %dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
```
}];
- let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [I32]>:$w,
- VectorOfLengthAndType<[16, 32, 64], [I8]>:$a,
- VectorOfLengthAndType<[16, 32, 64], [I8]>:$b
+ let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
+ VectorOfLengthAndType<[64], [I8]>:$a,
+ VectorOfLengthAndType<[64], [I8]>:$b
);
- let results = (outs VectorOfLengthAndType<[4, 8, 16], [I32]>:$dst);
+ let results = (outs VectorOfLengthAndType<[16], [I32]>:$dst);
let assemblyFormat =
"$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
let extraClassDeclaration = [{
std::string getIntrinsicName() {
- std::string intr = "llvm.x86.avx512.vpdpwssd";
- VectorType vecType = getW().getType();
- unsigned elemBitWidth = vecType.getElementTypeBitWidth();
- unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
- intr += "." + std::to_string(opBitWidth);
+ std::string intr = "llvm.x86.avx10.vpdpbssd.512";
return intr;
}
}];
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 89606d2681bf8..6868b55095461 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,27 +95,11 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
-// CHECK-LABEL: func @avx512_dot_i8_128
-func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
- %b: vector<16xi8>) -> vector<4xi32> {
- // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.128"
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
- return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_256
-func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
- %b: vector<32xi8>) -> vector<8xi32> {
- // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.256"
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_512
-func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
- // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vpdpwssd.512"
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 4bd88d473a614..672c32c9c3cc3 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,27 +94,11 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
-// CHECK-LABEL: func @avx512_dot_i8_128
-func.func @avx512_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
- %b: vector<16xi8>) -> vector<4xi32> {
- // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
- return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_256
-func.func @avx512_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
- %b: vector<32xi8>) -> vector<8xi32> {
- // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func @avx512_dot_i8_512
-func.func @avx512_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
- // CHECK: x86vector.avx512.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ // CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index b5277b5bb96bb..43d7f6a13daaa 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,27 +109,11 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}
-// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx512_vpdpwssd_128
-func.func @LLVM_x86_avx512_vpdpwssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
- %b: vector<16xi8>) -> vector<4xi32> {
- // CHECK: call <4 x i32> @llvm.x86.avx512.vpdpwssd.128(
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
- return %0 : vector<4xi32>
-}
-
-// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx512_vpdpwssd_256
-func.func @LLVM_x86_avx512_vpdpwssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
- %b: vector<32xi8>) -> vector<8xi32> {
- // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.256(
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx512_vpdpwssd_512
-func.func @LLVM_x86_avx512_vpdpwssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
+func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
- // CHECK: call <8 x i32> @llvm.x86.avx512.vpdpwssd.512(
- %0 = x86vector.avx512.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ // CHECK: call <8 x i32> @llvm.x86.avx10.vpdpbssd.512(
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
>From d241eec5f00f70db789d74aed92409e142f87a3b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 19:13:44 -0800
Subject: [PATCH 3/8] fix a typo in unit test
---
mlir/test/Target/LLVMIR/x86vector.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 43d7f6a13daaa..aad4a60720328 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -112,7 +112,7 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
%b: vector<64xi8>) -> vector<16xi32> {
- // CHECK: call <8 x i32> @llvm.x86.avx10.vpdpbssd.512(
+ // CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
%0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
return %0 : vector<16xi32>
}
>From 3ccd1810759996c3ebe2f365f1cab5094525f32e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 29 Jan 2026 19:17:42 -0800
Subject: [PATCH 4/8] fix a typo in description
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 299fec9410b39..7a47f2876472a 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -367,7 +367,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
let summary = "AVX10 Dot Int8 op";
let description = [{
The `dot` op is an AVX512-Int8 specific op that can lower to the proper
- LLVMAVX10-INT8 operation `llvm.vpdpwssd` depending on the width of MLIR
+ LLVMAVX10-INT8 operation `llvm.vpdpbssd` depending on the width of MLIR
vectors it is applied to.
#### From the Intel Intrinsics Guide:
>From 5e413df209d816984d5c3d9bceaf83ad43149acb Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 30 Jan 2026 16:55:09 -0800
Subject: [PATCH 5/8] resolving a few typos + string return
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 7a47f2876472a..4e2761e283150 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -366,11 +366,8 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
]> {
let summary = "AVX10 Dot Int8 op";
let description = [{
- The `dot` op is an AVX512-Int8 specific op that can lower to the proper
- LLVMAVX10-INT8 operation `llvm.vpdpbssd` depending on the width of MLIR
- vectors it is applied to.
-
- #### From the Intel Intrinsics Guide:
+ The `dot` op is an AVX10-Int8 specific op that can lower to the proper
+ LLVMAVX10-INT8 operation `llvm.vpdpbssdi.512`.
Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
@@ -392,8 +389,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
let extraClassDeclaration = [{
std::string getIntrinsicName() {
- std::string intr = "llvm.x86.avx10.vpdpbssd.512";
- return intr;
+ return "llvm.x86.avx10.vpdpbssd.512";
}
}];
}
>From 4d2cac750d3c371e5ee96ff2b53fe35bf935063b Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 30 Jan 2026 16:56:12 -0800
Subject: [PATCH 6/8] resolving a few typos + string return
---
mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4e2761e283150..cd16a36a6ee0b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -367,7 +367,7 @@ def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
let summary = "AVX10 Dot Int8 op";
let description = [{
The `dot` op is an AVX10-Int8 specific op that can lower to the proper
- LLVMAVX10-INT8 operation `llvm.vpdpbssdi.512`.
+ LLVMAVX10-INT8 operation `llvm.vpdpbssd.512`.
Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
>From 21eeb127047a3165aca9f65132b09ac966fc1f3a Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 1 Feb 2026 15:41:05 -0800
Subject: [PATCH 7/8] support for avx10 int8dp nano-kernel lowering.
---
.../VectorContractToPackedTypeDotProduct.cpp | 39 ++++++++++++++-----
1 file changed, 29 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a00a3e5bdd766..89aa53307b95d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -113,10 +113,11 @@ struct VectorContractToPackedTypeDotProduct
"RHS) dim and acc dim of size 4/8/16.");
if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
- nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+ nonUnitDim != 8 && nonUnitDim != 16 &&
+ nonUnitDimAcc.front() == nonUnitDim)
return rewriter.notifyMatchFailure(
contractOp, "Int8 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8.");
+ "RHS) dim and acc dim of size 4/8/16.");
auto loc = contractOp.getLoc();
auto castAcc = vector::ShapeCastOp::create(
@@ -159,10 +160,19 @@ struct VectorContractToPackedTypeDotProduct
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
- castAcc, bitcastLhsPkType, castRhs);
+ if (nonUnitDimAcc.front() == 16) {
+ dp = x86vector::AVX10DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ } else {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
}
} else {
auto castLhs = vector::ShapeCastOp::create(
@@ -192,10 +202,19 @@ struct VectorContractToPackedTypeDotProduct
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
- castAcc, castLhs, bitcastRhsPkType);
+ if (nonUnitDimAcc.front() == 16) {
+ dp = x86vector::AVX10DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ } else {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
}
}
>From d8363d5c5e6e4e202c6561cdc56257a717148531 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Sun, 1 Feb 2026 15:41:33 -0800
Subject: [PATCH 8/8] support for avx10 int8dp nano-kernel lowering.
---
...or-contract-to-packed-type-dotproduct.mlir | 103 ++++++++++++------
1 file changed, 69 insertions(+), 34 deletions(-)
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 65676cbae772c..a2231fc97993d 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -68,6 +68,75 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_avx10int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_avx10int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x16x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_avx10int8dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x4xi8>
!vecB = vector<1x1x8x4xi8>
!vecC = vector<1x8xi32>
@@ -614,40 +683,6 @@ module attributes {transform.with_named_sequence} {
// -----
-!vecA = vector<1x1x1x4xi8>
-!vecB = vector<1x1x16x4xi8>
-!vecC = vector<1x1x16xi32>
-#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
-#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
-#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
-func.func @negative_wrong_vector_shape_int8(
- %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
-{
- %0 = vector.contract {
- indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
- kind = #vector.kind<add>}
- %arg0, %arg1, %arg2
- : !vecA, !vecB into !vecC
- return %0 : !vecC
-}
-
-// CHECK-LABEL: @negative_wrong_vector_shape_int8
-// CHECK-NOT: x86vector.avx.dot.i8
-// CHECK: vector.contract
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_patterns to %func {
- transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
- } : !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
!vecA = vector<1x1x1x2xbf16>
!vecB = vector<1x1x32x2xbf16>
!vecC = vector<1x1x32xf32>
More information about the Mlir-commits
mailing list