[Mlir-commits] [mlir] f68af65 - [mlir][x86] Extends vector.contract Flat dot-product lowering for arg offset (#185167)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 8 20:56:31 PDT 2026


Author: Arun Thangamani
Date: 2026-03-09T09:26:27+05:30
New Revision: f68af65300181e27f097e6b6b4c5bdef9c77cc3c

URL: https://github.com/llvm/llvm-project/commit/f68af65300181e27f097e6b6b4c5bdef9c77cc3c
DIFF: https://github.com/llvm/llvm-project/commit/f68af65300181e27f097e6b6b4c5bdef9c77cc3c.diff

LOG: [mlir][x86] Extends vector.contract Flat dot-product lowering for arg offset (#185167)

Extends the `vector_contract_to_packed_type_dot_product` transform pass to
handle `args` offsets (index operands that are not constants, such as
function arguments) while validating the `vector.contract` pair.
E.g.: `vector.transfer_read %arg1[%arg3, %c0], %0 {in_bounds = [true,
true]} : !memref, !vec`

Added: 
    

Modified: 
    mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
    mlir/lib/Dialect/X86/Utils/X86Utils.cpp
    mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index e3037186569b8..b47eede2a9156 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -67,7 +67,15 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
                                         mlir::vector::ContractionOp contractB,
                                         int64_t nonUnitDimAcc,
                                         mlir::VectorType Ty) {
-  mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+  bool opABeforeopB = opA->isBeforeInBlock(opB);
+
+  if (opABeforeopB)
+    rewriter.moveOpAfter(opB, opA);
+  else
+    rewriter.moveOpAfter(opA, opB);
+
+  mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
 
   rewriter.setInsertionPointAfter(insertAfter);
   mlir::Location loc = insertAfter->getLoc();
@@ -326,14 +334,6 @@ struct VectorContractToPackedTypeDotProduct
           return rewriter.notifyMatchFailure(
               contractOp, "Could not find a valid contract pair");
 
-        if (contractOp->getBlock() ==
-                nonUnitDimReadOpPairContract->getBlock() &&
-            contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
-          return rewriter.notifyMatchFailure(
-              contractOp,
-              "The load/read operation of pair contract operation is "
-              "after the contractOp");
-
         VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
                                       ? contractOp.getRhsType()
                                       : contractOp.getLhsType();

diff  --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index 805d9c5c00b63..3893d8d288f32 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -341,6 +341,9 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
                                 vector::ContractionOp pairContOp,
                                 bool rhsHasMultipleNonUnitDims,
                                 int64_t nonUnitDimValue) {
+  if (contractOp == pairContOp)
+    return false;
+
   if (rhsHasMultipleNonUnitDims &&
       !(contractOp.getLhs() == pairContOp.getLhs()))
     return false;
@@ -393,21 +396,25 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
   if (srcBuff != srcBuffPairContOp)
     return false;
 
+  bool oneConstantOffset = false;
   for (size_t i = 0; i < indexVals.size(); i++) {
+
+    if (indexVals[i] == indexValsPairContOp[i])
+      continue;
+
     auto v0 = getConstantIntValue(indexVals[i]);
     auto v1 = getConstantIntValue(indexValsPairContOp[i]);
 
     if (!v0 || !v1)
       return false;
 
-    if (*v1 == *v0)
-      continue;
-
     if ((*v1 - *v0) != nonUnitDimValue)
       return false;
+
+    oneConstantOffset = true;
   }
 
-  return true;
+  return oneConstantOffset;
 }
 
 } // namespace x86

diff  --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
index eabf15c0af303..0953ee042a24d 100644
--- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
@@ -412,6 +412,76 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0,  d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0,  d1, d2) -> (d0, d1)>
+func.func @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc(
+  %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %0 = ub.poison : bf16
+  %32 = ub.poison : f32
+  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+        !memrefA, !vecA
+  %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+        !memrefC, !vecC
+  %4 = vector.transfer_read %arg1[%arg3, %c0], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %5 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %4, %2
+    : !vecA, !vecB into !vecC
+
+  %6 = vector.transfer_read %arg1[%arg3, %c16], %0 {in_bounds = [true, true]} :
+        !memrefB, !vecB
+
+  %7 = vector.contract {
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>}
+    %1, %6, %3
+    : !vecA, !vecB into !vecC
+
+  vector.transfer_write %5, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+  vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+  return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.avx512.dot
+// CHECK: x86.avx512.dot
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 !vecA = vector<1x1x2xbf16>
 !vecB = vector<1x2x16xbf16>
 !vecC = vector<1x16xf32>


        


More information about the Mlir-commits mailing list