[Mlir-commits] [mlir] 8e18cdc - [mlir][x86vector] AVX10 I8 Dot Op (#178807)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 2 03:19:06 PST 2026
Author: Arun Thangamani
Date: 2026-02-02T16:49:00+05:30
New Revision: 8e18cdcb151eacfc23b28f9fa70458f57430a178
URL: https://github.com/llvm/llvm-project/commit/8e18cdcb151eacfc23b28f9fa70458f57430a178
DIFF: https://github.com/llvm/llvm-project/commit/8e18cdcb151eacfc23b28f9fa70458f57430a178.diff
LOG: [mlir][x86vector] AVX10 I8 Dot Op (#178807)
Adds AVX10 i8 dot-product operation and defines lowering to LLVM
intrinsics.
Target LLVM intrinsic: `llvm.x86.avx10.vpdpbssd.512`
Added:
Modified:
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
mlir/test/Dialect/X86Vector/roundtrip.mlir
mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
mlir/test/Target/LLVMIR/x86vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 468242d1c2780..cd16a36a6ee0b 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -343,6 +343,58 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
}];
}
+//===----------------------------------------------------------------------===//
+// AVX10 op definitions
+//===----------------------------------------------------------------------===//
+
+// Operation that is part of the input dialect.
+class AVX10_Op<string mnemonic, list<Trait> traits = []> :
+ Op<X86Vector_Dialect, "avx10." # mnemonic, traits> {}
+
+//----------------------------------------------------------------------------//
+// AVX10 Int8 Dot
+//----------------------------------------------------------------------------//
+
+def AVX10DotInt8Op : AVX10_Op<"dot.i8", [Pure,
+ X86IntrinsicOpInterface,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["w", "dst"]>,
+ TypesMatchWith<"`a` has four times elements as `w`",
+ "w", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+ "IntegerType::get($_self.getContext(), 8))">
+ ]> {
+ let summary = "AVX10 Dot Int8 op";
+ let description = [{
+ The `dot` op is an AVX10-Int8 specific op that can lower to the proper
+ LLVMAVX10-INT8 operation `llvm.vpdpbssd.512`.
+
+ Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+ corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit
+ results. Sum these 4 results with the corresponding 32-bit integer in `w`, and
+ store the packed 32-bit results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$w,
+ VectorOfLengthAndType<[64], [I8]>:$a,
+ VectorOfLengthAndType<[64], [I8]>:$b
+ );
+ let results = (outs VectorOfLengthAndType<[16], [I32]>:$dst);
+ let assemblyFormat =
+ "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+ let extraClassDeclaration = [{
+ std::string getIntrinsicName() {
+ return "llvm.x86.avx10.vpdpbssd.512";
+ }
+ }];
+}
+
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a00a3e5bdd766..89aa53307b95d 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -113,10 +113,11 @@ struct VectorContractToPackedTypeDotProduct
"RHS) dim and acc dim of size 4/8/16.");
if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
- nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
+ nonUnitDim != 8 && nonUnitDim != 16 &&
+ nonUnitDimAcc.front() == nonUnitDim)
return rewriter.notifyMatchFailure(
contractOp, "Int8 dot-product operation expects non-unit (LHR or "
- "RHS) dim and acc dim of size 4/8.");
+ "RHS) dim and acc dim of size 4/8/16.");
auto loc = contractOp.getLoc();
auto castAcc = vector::ShapeCastOp::create(
@@ -159,10 +160,19 @@ struct VectorContractToPackedTypeDotProduct
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
- castAcc, bitcastLhsPkType, castRhs);
+ if (nonUnitDimAcc.front() == 16) {
+ dp = x86vector::AVX10DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ } else {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimRhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, bitcastLhsPkType, castRhs);
+ }
}
} else {
auto castLhs = vector::ShapeCastOp::create(
@@ -192,10 +202,19 @@ struct VectorContractToPackedTypeDotProduct
}
if (lhsTy.getElementType().isSignlessInteger(8)) {
- dp = x86vector::DotInt8Op::create(
- rewriter, loc,
- VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
- castAcc, castLhs, bitcastRhsPkType);
+ if (nonUnitDimAcc.front() == 16) {
+ dp = x86vector::AVX10DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ } else {
+ dp = x86vector::DotInt8Op::create(
+ rewriter, loc,
+ VectorType::get(nonUnitDimLhs.front(),
+ rewriter.getIntegerType(32)),
+ castAcc, castLhs, bitcastRhsPkType);
+ }
}
}
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 72dc899f4f0a6..6868b55095461 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,14 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx10.vpdpbssd.512"
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 959177b27c7ea..672c32c9c3cc3 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,14 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avx10_dot_i8_512
+func.func @avx10_dot_i8_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: x86vector.avx10.dot.i8 {{.*}} : vector<64xi8> -> vector<16xi32>
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
%a: memref<8xbf16>) -> vector<4xf32>
diff --git a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
index 65676cbae772c..e26a575e2bc90 100644
--- a/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86Vector/vector-contract-to-packed-type-dotproduct.mlir
@@ -68,6 +68,75 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x1x1x4xi8>
+!vecB = vector<1x1x16x4xi8>
+!vecC = vector<1x16xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d1, d2)>
+func.func @brgemm_to_avx10int8dp(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+// CHECK-LABEL: @brgemm_to_avx10int8dp
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x16x1x4xi8>
+!vecB = vector<1x1x1x4xi8>
+!vecC = vector<1x16x1xi32>
+#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
+#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
+#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
+func.func @batch_matmul_avx10int8dp_bcst_B(
+ %arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
+{
+ %0 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "reduction", "parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %arg0, %arg1, %arg2
+ : !vecA, !vecB into !vecC
+ return %0 : !vecC
+}
+
+
+// CHECK-LABEL: @batch_matmul_avx10int8dp_bcst_B
+// CHECK: vector.broadcast
+// CHECK: x86vector.avx10.dot.i8
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x1x4xi8>
!vecB = vector<1x1x8x4xi8>
!vecC = vector<1x8xi32>
@@ -615,8 +684,8 @@ module attributes {transform.with_named_sequence} {
// -----
!vecA = vector<1x1x1x4xi8>
-!vecB = vector<1x1x16x4xi8>
-!vecC = vector<1x1x16xi32>
+!vecB = vector<1x1x32x4xi8>
+!vecC = vector<1x1x32xi32>
#map = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d3, d4)>
#map1 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d3, d2, d4)>
#map2 = affine_map<(d0, d4, d1, d2, d3) -> (d0, d1, d2)>
@@ -634,6 +703,7 @@ func.func @negative_wrong_vector_shape_int8(
// CHECK-LABEL: @negative_wrong_vector_shape_int8
// CHECK-NOT: x86vector.avx.dot.i8
+// CHECK-NOT: x86vector.avx10.dot.i8
// CHECK: vector.contract
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 74ae2424964b1..aad4a60720328 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -109,6 +109,14 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: define <16 x i32> @LLVM_x86_avx10_vpdpbssd_512
+func.func @LLVM_x86_avx10_vpdpbssd_512(%w: vector<16xi32>, %a: vector<64xi8>,
+ %b: vector<64xi8>) -> vector<16xi32> {
+ // CHECK: call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(
+ %0 = x86vector.avx10.dot.i8 %w, %a, %b : vector<64xi8> -> vector<16xi32>
+ return %0 : vector<16xi32>
+}
+
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
%a: memref<8xbf16>) -> vector<4xf32>
More information about the Mlir-commits
mailing list