[Mlir-commits] [mlir] 587ba75 - [mlir][x86vector] AVX2 I8 Dot Op (#147908)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 11 04:19:10 PDT 2025
Author: arun-thmn
Date: 2025-07-11T13:19:07+02:00
New Revision: 587ba75a491d256500e4125c7c1de725c93fa84e
URL: https://github.com/llvm/llvm-project/commit/587ba75a491d256500e4125c7c1de725c93fa84e
DIFF: https://github.com/llvm/llvm-project/commit/587ba75a491d256500e4125c7c1de725c93fa84e.diff
LOG: [mlir][x86vector] AVX2 I8 Dot Op (#147908)
Adds an AVX2 i8 dot-product operation and defines its lowering to LLVM
intrinsics (`llvm.x86.avx2.vpdpbssd.128/.256`).
Target assembly instruction: `vpdpbssd` (128- and 256-bit forms)
Added:
Modified:
mlir/include/mlir/Dialect/X86Vector/X86Vector.td
mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
mlir/test/Dialect/X86Vector/roundtrip.mlir
mlir/test/Target/LLVMIR/x86vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 3bf0be0a716aa..73f6877c12fab 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -420,6 +420,62 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}
+//----------------------------------------------------------------------------//
+// AVX Int8 Dot
+//----------------------------------------------------------------------------//
+
+// Signed-int8 dot product accumulated into i32 lanes. Lowered through
+// X86IntrinsicOpInterface to `llvm.x86.avx2.vpdpbssd.{128,256}`.
+def DotInt8Op : AVX_Op<"dot.i8", [Pure,
+    X86IntrinsicOpInterface,
+    // `a` and `b` are the two i8 multiplicand vectors and share one type.
+    AllTypesMatch<["a", "b"]>,
+    // The accumulator and the result share one i32 vector type.
+    AllTypesMatch<["w", "dst"]>,
+    // Derive the type of `a` from `w`: N x i32 accumulators need 4*N x i8.
+    TypesMatchWith<"`a` has four times as many elements as `w`",
+      "w", "a",
+      "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+      "IntegerType::get($_self.getContext(), 8))">
+  ]> {
+  let summary = "Dot Int8 op";
+  let description = [{
+    The `dot.i8` op is an AVX-VNNI-INT8 (`vpdpbssd`) specific op that lowers
+    to the `llvm.x86.avx2.vpdpbssd.128` or `llvm.x86.avx2.vpdpbssd.256` LLVM
+    intrinsic, depending on the width of the MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate
+    signed 16-bit results. Sum these 4 results with the corresponding 32-bit
+    integer in `w`, and store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
+                       VectorOfLengthAndType<[16, 32], [I8]>:$a,
+                       VectorOfLengthAndType<[16, 32], [I8]>:$b
+  );
+  let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
+  // Only the types of `a` and `w` are spelled out; `b` and `dst` follow from
+  // the AllTypesMatch constraints above.
+  let assemblyFormat =
+    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+  let extraClassDeclaration = [{
+    /// Returns "llvm.x86.avx2.vpdpbssd.<bits>", where <bits> is the total
+    /// bit width of the i32 accumulator `w` (4 x 32 = 128 or 8 x 32 = 256).
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.avx2.vpdpbssd";
+      VectorType vecType = getW().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += "." + std::to_string(opBitWidth);
+      return intr;
+    }
+
+    /// Returns the intrinsic call operands: `w` unchanged, followed by `a`
+    /// and `b` bitcast to i32 vectors.
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+}
+
//----------------------------------------------------------------------------//
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index cc7ab7f3f3895..68aea48561283 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -86,6 +86,29 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
return intrinsicOperands;
}
+/// Builds the operand list for the `llvm.x86.avx2.vpdpbssd.{128,256}` call.
+///
+/// The accumulator `w` is forwarded unchanged. `a` and `b` are i8 vectors with
+/// four times as many elements as `w`; the intrinsic takes them as i32
+/// vectors, so each is bitcast (every i32 lane then packs four adjacent i8
+/// values). Because `a` and `b` have the same type (AllTypesMatch), the i32
+/// target type is computed once and reused. `typeConverter` is part of the
+/// hook's signature but is not needed here.
+SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+
+  // vector<16xi8> -> vector<4xi32>, vector<32xi8> -> vector<8xi32>.
+  VectorType i32VecType = VectorType::get(
+      {getA().getType().getShape()[0] / 4}, rewriter.getI32Type());
+  auto bitcastToI32 = [&](Value v) -> Value {
+    return rewriter.create<LLVM::BitcastOp>(getLoc(), i32VecType, v);
+  };
+
+  // Same SmallVector type as the return type, so it is returned without a
+  // conversion between inline sizes.
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.push_back(adaptor.getW());
+  intrinsicOperands.push_back(bitcastToI32(adaptor.getA()));
+  intrinsicOperands.push_back(bitcastToI32(adaptor.getB()));
+  return intrinsicOperands;
+}
+
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
RewriterBase &rewriter) {
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 63f06624ef897..72dc899f4f0a6 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // Both i8 operands are reinterpreted as i32 vectors before the call.
+ // CHECK-COUNT-2: llvm.bitcast %{{.*}} : vector<16xi8> to vector<4xi32>
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
+ // CHECK-SAME: (vector<4xi32>, vector<4xi32>, vector<4xi32>) -> vector<4xi32>
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // Both i8 operands are reinterpreted as i32 vectors before the call.
+ // CHECK-COUNT-2: llvm.bitcast %{{.*}} : vector<32xi8> to vector<8xi32>
+ // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
+ // CHECK-SAME: (vector<8xi32>, vector<8xi32>, vector<8xi32>) -> vector<8xi32>
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 7dcab3eb4dcb8..959177b27c7ea 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// Parse/print round trip of the 128-bit form (4 x i32 accumulator,
+// 16 x i8 operands).
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// Parse/print round trip of the 256-bit form (8 x i32 accumulator,
+// 32 x i8 operands).
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index d11dc89bdc7c9..74ae2424964b1 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
return %0 : vector<8xf32>
}
+
+// Translation to LLVM IR of the 128-bit form: the op becomes a call to the
+// `.128` intrinsic returning <4 x i32>.
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
+func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
+ %b: vector<16xi8>) -> vector<4xi32> {
+ // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+ return %0 : vector<4xi32>
+}
+
+// Translation to LLVM IR of the 256-bit form: the op becomes a call to the
+// `.256` intrinsic returning <8 x i32>.
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
+func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
+ %b: vector<32xi8>) -> vector<8xi32> {
+ // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+ %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+ return %0 : vector<8xi32>
+}
More information about the Mlir-commits
mailing list