[Mlir-commits] [mlir] [mlir][x86vector] AVX2 I8 Dot Op (PR #147908)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 11 04:05:37 PDT 2025


https://github.com/arun-thmn updated https://github.com/llvm/llvm-project/pull/147908

>From bf558772085f53fc3ef3a2722aefaf5fd4c4bcb1 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 10 Jul 2025 00:40:53 -0700
Subject: [PATCH 1/5] MLIR support for VPDPBSSD instruction through LLVM
 intrinsics

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 53 +++++++++++++++++++
 .../Dialect/X86Vector/legalize-for-llvm.mlir  | 16 ++++++
 mlir/test/Dialect/X86Vector/roundtrip.mlir    | 16 ++++++
 mlir/test/Target/LLVMIR/x86vector.mlir        | 16 ++++++
 4 files changed, 101 insertions(+)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 3bf0be0a716aa..c3f7904fa1249 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -302,6 +302,8 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
   }];
 }
 
+
+
 //----------------------------------------------------------------------------//
 // Convert packed F32 to packed BF16
 //----------------------------------------------------------------------------//
@@ -420,6 +422,57 @@ def DotOp : AVX_LowOp<"dot", [Pure,
   }];
 }
 
+//----------------------------------------------------------------------------//
+// AVX Int8 Dot
+//----------------------------------------------------------------------------//
+
+def DotInt8Op : AVX_Op<"dot.i32", [Pure,
+    X86IntrinsicOpInterface,
+    AllTypesMatch<["a", "b"]>,
+    AllTypesMatch<["src", "dst"]>,
+    TypesMatchWith<"`a` has same elements as `src`",
+                   "src", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
+                   "IntegerType::get($_self.getContext(), 32))">
+  ]> {
+  let summary = "Dot Int8 op";
+  let description = [{
+    The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper
+    LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR
+    vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with 
+    corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit 
+    results. Sum these 4 results with the corresponding 32-bit integer in `src`, and 
+    store the packed 32-bit results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
+    ```
+  }];
+  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
+                   VectorOfLengthAndType<[4, 8], [I32]>:$a,
+                   VectorOfLengthAndType<[4, 8], [I32]>:$b
+                   );
+  let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
+  let assemblyFormat =
+    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+
+  let extraClassDeclaration = [{
+    std::string getIntrinsicName() {
+      std::string intr = "llvm.x86.avx2.vpdpbssd";
+      VectorType vecType = getSrc().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += "." + std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+}
+
 //----------------------------------------------------------------------------//
 // AVX: Convert BF16/F16 to F32 and broadcast into packed F32
 //----------------------------------------------------------------------------//
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 63f06624ef897..a70410497acbd 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -219,3 +219,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+  // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 7dcab3eb4dcb8..bd3509eb07b2b 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -229,3 +229,19 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @avx_dot_i32_128
+func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32>
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: func @avx_dot_i32_256
+func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32>
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index d11dc89bdc7c9..ac2f3e1277df3 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -234,3 +234,19 @@ func.func @LLVM_x86_avx_dp_ps_256(
   %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
+func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>,
+    %b: vector<4xi32>) -> vector<4xi32> {
+    // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
+func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>,
+    %b: vector<8xi32>) -> vector<8xi32> {
+    // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  return %0 : vector<8xi32>
+}

>From df153cfc5d8e9c60a5fb38b058a3bfffcce37516 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Thu, 10 Jul 2025 00:49:48 -0700
Subject: [PATCH 2/5] Remove extra blank lines

---
 mlir/include/mlir/Dialect/X86Vector/X86Vector.td | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index c3f7904fa1249..688f26211e48d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -302,8 +302,6 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
   }];
 }
 
-
-
 //----------------------------------------------------------------------------//
 // Convert packed F32 to packed BF16
 //----------------------------------------------------------------------------//

>From 72933a035fe3dbe25b33d2ae64ea2bbf7191ac6f Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 11 Jul 2025 00:32:13 -0700
Subject: [PATCH 3/5] Change the op to take i8 operands (i32 += i8 * i8), bitcasting them to i32 with llvm.bitcast

---
 .../mlir/Dialect/X86Vector/X86Vector.td       | 35 +++++++++++--------
 .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 19 ++++++++++
 .../Dialect/X86Vector/legalize-for-llvm.mlir  | 16 ++++-----
 mlir/test/Dialect/X86Vector/roundtrip.mlir    | 20 +++++------
 mlir/test/Target/LLVMIR/x86vector.mlir        | 16 ++++-----
 5 files changed, 65 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 688f26211e48d..73f6877c12fab 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -424,50 +424,55 @@ def DotOp : AVX_LowOp<"dot", [Pure,
 // AVX Int8 Dot
 //----------------------------------------------------------------------------//
 
-def DotInt8Op : AVX_Op<"dot.i32", [Pure,
+def DotInt8Op : AVX_Op<"dot.i8", [Pure,
     X86IntrinsicOpInterface,
     AllTypesMatch<["a", "b"]>,
-    AllTypesMatch<["src", "dst"]>,
-    TypesMatchWith<"`a` has same elements as `src`",
-                   "src", "a",
-                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
-                   "IntegerType::get($_self.getContext(), 32))">
+    AllTypesMatch<["w", "dst"]>,
+    TypesMatchWith<"`a` has four times elements as `w`",
+                   "w", "a",
+                   "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 4}, "
+                   "IntegerType::get($_self.getContext(), 8))">
   ]> {
   let summary = "Dot Int8 op";
   let description = [{
-    The `dot` op is an AVX2-I32/I8 specific op that can lower to the proper
-    LLVMAVX2-INT8/32 operation `llvm.vpdpbssd` depending on the width of MLIR
+    The `dot.i8` op is an AVX2-Int8 specific op that can lower to the proper
+    LLVM intrinsic `llvm.x86.avx2.vpdpbssd.{128|256}` depending on the width of MLIR
     vectors it is applied to.
 
     #### From the Intel Intrinsics Guide:
 
     Multiply groups of 4 adjacent pairs of signed 8-bit integers in `a` with 
     corresponding signed 8-bit integers in `b`, producing 4 intermediate signed 16-bit 
-    results. Sum these 4 results with the corresponding 32-bit integer in `src`, and 
+    results. Sum these 4 results with the corresponding 32-bit integer in `w`, and 
     store the packed 32-bit results in `dst`.
 
     Example:
     ```mlir
-    %dst = x86vector.avx.dot %src, %a, %b : vector<8xi32> -> vector<8xi32>
+    %dst = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
     ```
   }];
-  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$src,
-                   VectorOfLengthAndType<[4, 8], [I32]>:$a,
-                   VectorOfLengthAndType<[4, 8], [I32]>:$b
+  let arguments = (ins VectorOfLengthAndType<[4, 8], [I32]>:$w,
+                   VectorOfLengthAndType<[16, 32], [I8]>:$a,
+                   VectorOfLengthAndType<[16, 32], [I8]>:$b
                    );
   let results = (outs VectorOfLengthAndType<[4, 8], [I32]>:$dst);
   let assemblyFormat =
-    "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+    "$w `,` $a `,` $b attr-dict `:` type($a) `->` type($w)";
 
   let extraClassDeclaration = [{
     std::string getIntrinsicName() {
       std::string intr = "llvm.x86.avx2.vpdpbssd";
-      VectorType vecType = getSrc().getType();
+      VectorType vecType = getW().getType();
       unsigned elemBitWidth = vecType.getElementTypeBitWidth();
       unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
       intr += "." + std::to_string(opBitWidth);
       return intr;
     }
+
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index cc7ab7f3f3895..7dddf55010800 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -86,6 +86,25 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return intrinsicOperands;
 }
 
+SmallVector<Value>
+x86vector::DotInt8Op::getIntrinsicOperands(ArrayRef<Value> operands,
+                                       const LLVMTypeConverter &typeConverter,
+                                       RewriterBase &rewriter) {
+  SmallVector<Value, 3> intrinsicOprnds;
+  intrinsicOprnds.push_back(operands[0]);
+
+  //Bit-cast `a` and `b` to i32
+  Value bitcast_a = rewriter.create<LLVM::BitcastOp>(
+                        getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)),
+                        operands[1]);
+  intrinsicOprnds.push_back(bitcast_a);
+  Value bitcast_b = rewriter.create<LLVM::BitcastOp>(
+                        getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)),
+                        operands[2]);
+  intrinsicOprnds.push_back(bitcast_b);
+  return intrinsicOprnds;
+}
+
 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index a70410497acbd..72dc899f4f0a6 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -220,18 +220,18 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: func @avx_dot_i32_128
-func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
-    %b: vector<4xi32>) -> vector<4xi32> {
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
   // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.128"
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
   return %0 : vector<4xi32>
 }
 
-// CHECK-LABEL: func @avx_dot_i32_256
-func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
-    %b: vector<8xi32>) -> vector<8xi32> {
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
   // CHECK: llvm.call_intrinsic "llvm.x86.avx2.vpdpbssd.256"
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
   return %0 : vector<8xi32>
 }
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index bd3509eb07b2b..959177b27c7ea 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -230,18 +230,18 @@ func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
   return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: func @avx_dot_i32_128
-func.func @avx_dot_i32_128(%src: vector<4xi32>, %a: vector<4xi32>,
-    %b: vector<4xi32>) -> vector<4xi32> {
-  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<4xi32> -> vector<4xi32>
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+// CHECK-LABEL: func @avx_dot_i8_128
+func.func @avx_dot_i8_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<16xi8> -> vector<4xi32>
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
   return %0 : vector<4xi32>
 }
 
-// CHECK-LABEL: func @avx_dot_i32_256
-func.func @avx_dot_i32_256(%src: vector<8xi32>, %a: vector<8xi32>,
-    %b: vector<8xi32>) -> vector<8xi32> {
-  // CHECK: x86vector.avx.dot.i32 {{.*}} : vector<8xi32> -> vector<8xi32>
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+// CHECK-LABEL: func @avx_dot_i8_256
+func.func @avx_dot_i8_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: x86vector.avx.dot.i8 {{.*}} : vector<32xi8> -> vector<8xi32>
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
   return %0 : vector<8xi32>
 }
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index ac2f3e1277df3..74ae2424964b1 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -236,17 +236,17 @@ func.func @LLVM_x86_avx_dp_ps_256(
 }
 
 // CHECK-LABEL: define <4 x i32> @LLVM_x86_avx2_vpdpbssd_128
-func.func @LLVM_x86_avx2_vpdpbssd_128(%src: vector<4xi32>, %a: vector<4xi32>,
-    %b: vector<4xi32>) -> vector<4xi32> {
-    // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<4xi32> -> vector<4xi32>
+func.func @LLVM_x86_avx2_vpdpbssd_128(%w: vector<4xi32>, %a: vector<16xi8>,
+    %b: vector<16xi8>) -> vector<4xi32> {
+  // CHECK: call <4 x i32> @llvm.x86.avx2.vpdpbssd.128(
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<16xi8> -> vector<4xi32>
   return %0 : vector<4xi32>
 }
 
 // CHECK-LABEL: define <8 x i32> @LLVM_x86_avx2_vpdpbssd_256
-func.func @LLVM_x86_avx2_vpdpbssd_256(%src: vector<8xi32>, %a: vector<8xi32>,
-    %b: vector<8xi32>) -> vector<8xi32> {
-    // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
-  %0 = x86vector.avx.dot.i32 %src, %a, %b : vector<8xi32> -> vector<8xi32>
+func.func @LLVM_x86_avx2_vpdpbssd_256(%w: vector<8xi32>, %a: vector<32xi8>,
+    %b: vector<32xi8>) -> vector<8xi32> {
+  // CHECK: call <8 x i32> @llvm.x86.avx2.vpdpbssd.256(
+  %0 = x86vector.avx.dot.i8 %w, %a, %b : vector<32xi8> -> vector<8xi32>
   return %0 : vector<8xi32>
 }

>From c902a46f59414d20a824b905f741398698378400 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 11 Jul 2025 00:35:02 -0700
Subject: [PATCH 4/5] Apply clang-format to the C++ file

---
 .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 7dddf55010800..64d6790306941 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -86,22 +86,25 @@ x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
   return intrinsicOperands;
 }
 
-SmallVector<Value>
-x86vector::DotInt8Op::getIntrinsicOperands(ArrayRef<Value> operands,
-                                       const LLVMTypeConverter &typeConverter,
-                                       RewriterBase &rewriter) {
+SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
   SmallVector<Value, 3> intrinsicOprnds;
   intrinsicOprnds.push_back(operands[0]);
-
-  //Bit-cast `a` and `b` to i32
+  // Bitcast `a` and `b` to i32
   Value bitcast_a = rewriter.create<LLVM::BitcastOp>(
-                        getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)),
-                        operands[1]);
+      getLoc(),
+      VectorType::get((getA().getType().getShape()[0] / 4),
+                      rewriter.getIntegerType(32)),
+      operands[1]);
   intrinsicOprnds.push_back(bitcast_a);
   Value bitcast_b = rewriter.create<LLVM::BitcastOp>(
-                        getLoc(), VectorType::get((getA().getType().getShape()[0]/4), rewriter.getIntegerType(32)),
-                        operands[2]);
+      getLoc(),
+      VectorType::get((getB().getType().getShape()[0] / 4),
+                      rewriter.getIntegerType(32)),
+      operands[2]);
   intrinsicOprnds.push_back(bitcast_b);
+
   return intrinsicOprnds;
 }
 

>From e7ae0717a4d152127de9d139a7fa4bab0d2d4b26 Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Fri, 11 Jul 2025 04:05:20 -0700
Subject: [PATCH 5/5] Use the op adaptor to access operands in the C++ lowering

---
 mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 64d6790306941..68aea48561283 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -90,19 +90,20 @@ SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
     ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
     RewriterBase &rewriter) {
   SmallVector<Value, 3> intrinsicOprnds;
-  intrinsicOprnds.push_back(operands[0]);
+  Adaptor adaptor(operands, *this);
+  intrinsicOprnds.push_back(adaptor.getW());
   // Bitcast `a` and `b` to i32
   Value bitcast_a = rewriter.create<LLVM::BitcastOp>(
       getLoc(),
       VectorType::get((getA().getType().getShape()[0] / 4),
                       rewriter.getIntegerType(32)),
-      operands[1]);
+      adaptor.getA());
   intrinsicOprnds.push_back(bitcast_a);
   Value bitcast_b = rewriter.create<LLVM::BitcastOp>(
       getLoc(),
       VectorType::get((getB().getType().getShape()[0] / 4),
                       rewriter.getIntegerType(32)),
-      operands[2]);
+      adaptor.getB());
   intrinsicOprnds.push_back(bitcast_b);
 
   return intrinsicOprnds;



More information about the Mlir-commits mailing list