[Mlir-commits] [mlir] 587ba75 - [mlir][x86vector] AVX2 I8 Dot Op (#147908)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 11 04:19:10 PDT 2025


Author: arun-thmn
Date: 2025-07-11T13:19:07+02:00
New Revision: 587ba75a491d256500e4125c7c1de725c93fa84e

URL: https://github.com/llvm/llvm-project/commit/587ba75a491d256500e4125c7c1de725c93fa84e
DIFF: https://github.com/llvm/llvm-project/commit/587ba75a491d256500e4125c7c1de725c93fa84e.diff

LOG: [mlir][x86vector] AVX2 I8 Dot Op (#147908)

Adds AVX2 i8 dot-product operation and defines lowering to LLVM
intrinsics.

Target assembly instruction: `vpdpbssd.128/256`

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/X86Vector/X86Vector.td
    mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
    mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
    mlir/test/Dialect/X86Vector/roundtrip.mlir
    mlir/test/Target/LLVMIR/x86vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 3bf0be0a716aa..73f6877c12fab 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -420,6 +420,62 @@ def DotOp : AVX_LowOp<"dot", [Pure,
   }];
 }
 
+//----------------------------------------------------------------------------//
+// AVX Int8 Dot
+//----------------------------------------------------------------------------//
+
+def DotInt8Op : AVX_Op<"dot.i8", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["a", "b"]>,
+    AllTypesMatch<["w", "dst"]>,
+    TypesMatchWith<"`a` has four times as many elements as `w`",
+                   "w", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+                   "IntegerType::get($_self.getContext(), 8))">
+  ]> {
+  let summary = "Dot Int8 op";
+  let description = [{
+    The `dot.i8` op is an AVX-VNNI-INT8 specific op that lowers to the LLVM
+    intrinsic `llvm.x86.avx2.vpdpbssd.128` or `llvm.x86.avx2.vpdpbssd.256`,
+    depending on the width of the MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate
+    signed 16-bit results. Sum these 4 results with the corresponding 32-bit
+    integer in `w`, and store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
+                   VectorOfLengthAndType<[16, 32], [I8]>:$a,
+                   VectorOfLengthAndType<[16, 32], [I8]>:$b
+                   );
+  let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
+  let assemblyFormat =
+    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
+
+  let extraClassDeclaration = [{
+    // Suffix is the total bit width of `w`: 4 x i32 -> 128, 8 x i32 -> 256.
+    std::string getIntrinsicName() {
+      VectorType vecType = getW().getType();
+      unsigned opBitWidth =
+          vecType.getShape()[0] * vecType.getElementTypeBitWidth();
+      return "llvm.x86.avx2.vpdpbssd." + std::to_string(opBitWidth);
+    }
+
+    // Bitcasts `a`/`b` to i32 vectors; defined in X86VectorDialect.cpp.
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
+  }];
+}
+
 //----------------------------------------------------------------------------//
 // AVX: Convert BF16/F16 to F32 and broadcast into packed F32
 //----------------------------------------------------------------------------//

diff  --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index cc7ab7f3f3895..68aea48561283 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -86,6 +86,29 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return intrinsicOperands;
 }
 
+/// Builds the operand list of `llvm.x86.avx2.vpdpbssd.{128,256}` from the
+/// already-converted operands of `x86vector.avx.dot.i8`.
+SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  // The intrinsic takes `a` and `b` as i32 vectors, each i32 packing four
+  // adjacent i8 lanes: bitcast e.g. vector<32xi8> -> vector<8xi32>.
+  auto packToI32 = [&](VectorType i8VecType, Value i8Vec) -> Value {
+    auto i32VecType =
+        VectorType::get(i8VecType.getShape()[0] / 4, rewriter.getI32Type());
+    return rewriter.create<LLVM::BitcastOp>(getLoc(), i32VecType, i8Vec);
+  };
+
+  // `w` (the i32 accumulator) is forwarded unchanged; only the i8 inputs
+  // need repacking. Operand order matches the intrinsic: (w, a, b).
+  SmallVector<Value> intrinsicOperands;
+  intrinsicOperands.push_back(adaptor.getW());
+  intrinsicOperands.push_back(packToI32(getA().getType(), adaptor.getA()));
+  intrinsicOperands.push_back(packToI32(getB().getType(), adaptor.getB()));
+  return intrinsicOperands;
+}
+
 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {

diff  --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 63f06624ef897..72dc899f4f0a6 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>, %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK-COUNT-2: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32>
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>, %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK-COUNT-2: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32>
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}

diff  --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 7dcab3eb4dcb8..959177b27c7ea 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%acc: vector<4xi32>, %lhs: vector<16xi8>,
+                          %rhs: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
+  %dot = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<16xi8> -> vector<4xi32>
+  return %dot : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%acc: vector<8xi32>, %lhs: vector<32xi8>,
+                          %rhs: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
+  %dot = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<32xi8> -> vector<8xi32>
+  return %dot : vector<8xi32>
+}

diff  --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index d11dc89bdc7c9..74ae2424964b1 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
+func.func @LLVM_x86_avx2_vpdpbssd_128(%acc: vector<4xi32>,
+    %lhs: vector<16xi8>, %rhs: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+  %dot = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<16xi8> -> vector<4xi32>
+  return %dot : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
+func.func @LLVM_x86_avx2_vpdpbssd_256(%acc: vector<8xi32>,
+    %lhs: vector<32xi8>, %rhs: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+  %dot = x86vector.avx.dot.i8 %acc, %lhs, %rhs : vector<32xi8> -> vector<8xi32>
+  return %dot : vector<8xi32>
+}


        


More information about the Mlir-commits mailing list