[Mlir-commits] [mlir] [mlir][x86] Lower Int8 vector.contract to AVX2/AVX10 dp (online packing) (PR #189386)

Arun Thangamani llvmlistbot at llvm.org
Mon Mar 30 06:57:54 PDT 2026


https://github.com/arun-thmn created https://github.com/llvm/llvm-project/pull/189386

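Lower signless int8 `vector.contract` ops in flat (non-VNNI) layout to AVX2 (`x86.avx.dot.i8`) and AVX10 (`x86.avx10.dot.i8`) dot products, packing the B operand online with `vector.shuffle` instead of requiring a pre-packed VNNI layout.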

>From 62af1799bcf421abd8121bb648334335cde65a2e Mon Sep 17 00:00:00 2001
From: Arun Thangamani <arun.thangamani at intel.com>
Date: Mon, 30 Mar 2026 06:43:33 -0700
Subject: [PATCH] [mlir][x86] Lower int8 flat-layout vector.contract to AVX2/AVX10 dot products

---
 .../VectorContractToPackedTypeDotProduct.cpp  |  82 ++++-
 mlir/lib/Dialect/X86/Utils/X86Utils.cpp       |  18 +-
 ...or-contract-to-packed-type-dotproduct.mlir | 297 ++++++++++++++++++
 3 files changed, 380 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a4496f3620b97..0e4b39045571d 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -82,22 +82,78 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
   auto elemTy = Ty.getElementType();
   auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
 
-  auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opA->getResult(0));
-  auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
-                                                 opB->getResult(0));
+  Value srcBuff;
+  SmallVector<Value> indexVals;
 
-  static constexpr int64_t maskLo[] = {
+  llvm::TypeSwitch<Operation *>(opA).Case<TransferReadOp, LoadOp>(
+      [&](auto readOp) {
+        srcBuff = readOp.getOperand(0);
+
+        auto indices = readOp.getIndices();
+        indexVals.reserve(indices.size());
+
+        llvm::transform(
+            indices, std::back_inserter(indexVals), [&](OpFoldResult ofr) {
+              return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+            });
+      });
+
+  auto vec1 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals);
+
+  unsigned int offset = 1;
+  if (elemTy.isSignlessInteger(8))
+    offset = 2;
+
+  Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  auto nextIndx =
+      arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), cOffset,
+                            indexVals[indexVals.size() - 2]);
+  indexVals[indexVals.size() - 2] = nextIndx;
+
+  auto vec2 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals);
+
+  static constexpr int64_t maskLo_bf16[] = {
       0,  32, 1,  33, 2,  34, 3,  35, 8,  40, 9,  41, 10, 42, 11, 43,
       16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
-  static constexpr int64_t maskHi[] = {
+  static constexpr int64_t maskHi_bf16[] = {
       4,  36, 5,  37, 6,  38, 7,  39, 12, 44, 13, 45, 14, 46, 15, 47,
       20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
 
-  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, maskLo);
-  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
-                                                   castB, maskHi);
+  static constexpr int64_t maskLo_int8_avx2[] = {
+      0, 16, 32, 48, 1, 17, 33, 49, 2,  18, 34, 50, 3,  19, 35, 51,
+      8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
+  static constexpr int64_t maskHi_int8_avx2[] = {
+      4,  20, 36, 52, 5,  21, 37, 53, 6,  22, 38, 54, 7,  23, 39, 55,
+      12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
+
+  static constexpr int64_t maskLo_int8[] = {
+      0,  32, 64, 96,  1,  33, 65, 97,  2,  34, 66, 98,  3,  35, 67, 99,
+      8,  40, 72, 104, 9,  41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
+      16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
+      24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
+  static constexpr int64_t maskHi_int8[] = {
+      4,  36, 68, 100, 5,  37, 69, 101, 6,  38, 70, 102, 7,  39, 71, 103,
+      12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
+      20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
+      28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
+
+  mlir::DenseI64ArrayAttr maskLo = rewriter.getDenseI64ArrayAttr(maskLo_bf16);
+  mlir::DenseI64ArrayAttr maskHi = rewriter.getDenseI64ArrayAttr(maskHi_bf16);
+
+  if (elemTy.isSignlessInteger(8)) {
+    maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8);
+    maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8);
+
+    if (nonUnitDimAcc == 32) {
+      maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx2);
+      maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx2);
+    }
+  }
+
+  auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
+                                                   vec2, maskLo);
+  auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
+                                                   vec2, maskHi);
 
   auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
   auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
@@ -159,8 +215,8 @@ struct VectorContractToPackedTypeDotProduct
         isInVnniLayout(contractOp.getOperation(),
                        contractOp.getIndexingMapsArray(), blockingFactor);
 
-    if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
-      return failure();
+    // if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
+    // return failure();
 
     VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
     if (!accTy)
@@ -217,7 +273,7 @@ struct VectorContractToPackedTypeDotProduct
 
     if (!isVnni && (extraFlatDim != blockingFactor))
       return rewriter.notifyMatchFailure(
-          contractOp, "The K or reduction dim for flat layout should be 2.");
+          contractOp, "The K or reduction dim for flat layout should be 2/4.");
 
     if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
         (lhsTy.getElementType().isSignlessInteger(8) &&
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index f4279a3eb507a..aceea754639bb 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -116,15 +116,20 @@ struct ShuffleMasks {
   llvm::ArrayRef<int64_t> maskHi;
 };
 
-inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2) {
   // We only support these two layouts for now.
   assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
          "Unsupported nonUnitDimAcc value");
+
   // Do interleaving between two <8xf32> targeting AVX2.
   static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
   static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
 
-  // Shuffle two <16xf32> as below targeting AVX512.
+  // Gather the low (resp. high) 4-element halves of two <8xi32> for int8 AVX2.
+  static constexpr int64_t maskLo8_avx2_int8[] = {0, 1, 2, 3, 8, 9, 10, 11};
+  static constexpr int64_t maskHi8_avx2_int8[] = {4, 5, 6, 7, 12, 13, 14, 15};
+
+  // Shuffle two <16xf32/i32> as below targeting AVX512.
   static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
                                          4, 5, 6, 7, 20, 21, 22, 23};
   static constexpr int64_t maskHi16[] = {8,  9,  10, 11, 24, 25, 26, 27,
@@ -133,6 +138,9 @@ inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
   if (nonUnitDimAcc == 16)
     return {maskLo16, maskHi16};
 
+  if (isInt8Avx2)
+    return {maskLo8_avx2_int8, maskHi8_avx2_int8};
+
   return {maskLo8, maskHi8};
 }
 
@@ -255,7 +263,8 @@ LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
   auto castB =
       vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));
 
-  auto masks = getShuffleMasks(nonUnitDimAcc);
+  auto masks = getShuffleMasks(
+      nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
 
   auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                              castB, masks.maskLo);
@@ -313,7 +322,8 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
   auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
 
   // TODO: derive shuffle masks instead of hard-coding
-  auto masks = getShuffleMasks(nonUnitDimAcc);
+  auto masks = getShuffleMasks(
+      nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
 
   auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
                                               castB, masks.maskLo);
diff --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
index 0953ee042a24d..f861d357739a3 100644
--- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
@@ -412,6 +412,144 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x4xi8>
+!vecB = vector<4x16xi8>
+!vecC = vector<1x16xi32>
+!memrefA = memref<4x4xi8>
+!memrefB = memref<4x64xi8>
+!memrefC = memref<4x64xi32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_i8_avx10dp_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %4, %2
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %5, %3
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_i8_avx10dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.avx10.dot.i8
+// CHECK: x86.avx10.dot.i8
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x4xi8>
+!vecB = vector<4x8xi8>
+!vecC = vector<1x8xi32>
+!memrefA = memref<4x4xi8>
+!memrefB = memref<4x64xi8>
+!memrefC = memref<4x64xi32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_i8_avx2dp_flat_layout(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %5 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %4, %2
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %5, %3
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_i8_avx2dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32>
+// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59] : vector<32xi8>, vector<32xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8>
+// CHECK: x86.avx.dot.i8
+// CHECK: x86.avx.dot.i8
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x2xbf16>
 !vecB = vector<2x16xbf16>
 !vecC = vector<1x16xf32>
@@ -640,6 +778,102 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled(
 // CHECK: x86.avx512.dot
 // CHECK-NOT: vector.contract
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x4x16xi8>
+!vecC = vector<1x16xi32>
+!memrefA = memref<1x2x4xi8, strided<[16384, 256, 1], offset: ?>>
+!memrefB = memref<1x4x32xi8, strided<[32768, 128, 1], offset: ?>>
+!memrefC = memref<2x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @brgemm_int8_flat_avx10(%arg0: memref<16x64x256xi8>, %arg1: memref<16x256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+  %0 = ub.poison : i32
+  %1 = ub.poison : i8
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  %c2 = arith.constant 2 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg3 = %c0 to %c64 step %c2 {
+    scf.for %arg4 = %c0 to %c128 step %c32 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [2, 32] [1, 1] : memref<64x128xi32> to !memrefC
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %4 = vector.transfer_read %subview[%c1, %c0], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %5 = vector.transfer_read %subview[%c1, %c16], %0 {in_bounds = [true, true]}
+                : !memrefC, !vecC
+      %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+        %7:4 = scf.for %arg10 = %c0 to %c256 step %c4 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 2, 4] [1, 1, 1] : memref<16x64x256xi8> to !memrefA
+          %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 4, 32] [1, 1, 1] : memref<16x256x128xi8> to !memrefB
+          %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+                : !memrefA, !vecA
+          %9 = vector.transfer_read %subview_0[%c0, %c1, %c0], %1 {in_bounds = [true, true, true]}
+                : !memrefA, !vecA
+          %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+                : !memrefB, !vecB
+          %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]}
+                : !memrefB, !vecB
+          %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %10, %arg11 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+          %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %8, %11, %arg12 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+          %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %9, %10, %arg13 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+          %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+                ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                %9, %11, %arg14 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+          scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+        }
+        scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+      }
+      vector.transfer_write %6#3, %subview[%c1, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#2, %subview[%c1, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+      vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+    }
+  }
+  %alloc = memref.alloc() : memref<64x128xi32>
+  memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+  return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @brgemm_int8_flat_avx10
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.avx10.dot.i8
+// CHECK: x86.avx10.dot.i8
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -1548,3 +1782,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+!vecA = vector<1x4xi8>
+!vecB = vector<4x8xi8>
+!vecC = vector<1x8xi32>
+!memrefA = memref<4x4xi8>
+!memrefB = memref<4x64xi8>
+!memrefC = memref<4x64xi32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @negative_i8_avx2dp_flat_layout_offset_diff_16(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : i8
+  %32 = ub.poison : i32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+  %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %6 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %4, %2
+    : !vecA, !vecB into !vecC
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %5, %3
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @negative_i8_avx2dp_flat_layout_offset_diff_16
+// CHECK-NOT: x86.avx.dot.i8
+// CHECK: vector.contract
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list