[Mlir-commits] [mlir] f68af65 - [mlir][x86] Extends vector.contract Flat dot-product lowering for arg offset (#185167)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 8 20:56:31 PDT 2026
Author: Arun Thangamani
Date: 2026-03-09T09:26:27+05:30
New Revision: f68af65300181e27f097e6b6b4c5bdef9c77cc3c
URL: https://github.com/llvm/llvm-project/commit/f68af65300181e27f097e6b6b4c5bdef9c77cc3c
DIFF: https://github.com/llvm/llvm-project/commit/f68af65300181e27f097e6b6b4c5bdef9c77cc3c.diff
LOG: [mlir][x86] Extends vector.contract Flat dot-product lowering for arg offset (#185167)
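Extends the `vector_contract_to_packed_type_dot_product` transform pass so that,
when validating a pair of `vector.contract` ops, read offsets given by
non-constant values (e.g. function arguments) are accepted as long as they are
identical in both reads, instead of requiring every offset to be a constant.
It also removes the restriction that the pair contraction's read must appear
before the `contractOp`.
E.g.: `vector.transfer_read %arg1[%arg3, %c0], %0 {in_bounds = [true, true]} : !memref, !vec`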
Added:
Modified:
mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
mlir/lib/Dialect/X86/Utils/X86Utils.cpp
mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index e3037186569b8..b47eede2a9156 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -67,7 +67,15 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
mlir::vector::ContractionOp contractB,
int64_t nonUnitDimAcc,
mlir::VectorType Ty) {
- mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
+
+ bool opABeforeopB = opA->isBeforeInBlock(opB);
+
+ if (opABeforeopB)
+ rewriter.moveOpAfter(opB, opA);
+ else
+ rewriter.moveOpAfter(opA, opB);
+
+ mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
rewriter.setInsertionPointAfter(insertAfter);
mlir::Location loc = insertAfter->getLoc();
@@ -326,14 +334,6 @@ struct VectorContractToPackedTypeDotProduct
return rewriter.notifyMatchFailure(
contractOp, "Could not find a valid contract pair");
- if (contractOp->getBlock() ==
- nonUnitDimReadOpPairContract->getBlock() &&
- contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
- return rewriter.notifyMatchFailure(
- contractOp,
- "The load/read operation of pair contract operation is "
- "after the contractOp");
-
VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
? contractOp.getRhsType()
: contractOp.getLhsType();
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index 805d9c5c00b63..3893d8d288f32 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -341,6 +341,9 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
vector::ContractionOp pairContOp,
bool rhsHasMultipleNonUnitDims,
int64_t nonUnitDimValue) {
+ if (contractOp == pairContOp)
+ return false;
+
if (rhsHasMultipleNonUnitDims &&
!(contractOp.getLhs() == pairContOp.getLhs()))
return false;
@@ -393,21 +396,25 @@ bool validatePairVectorContract(vector::ContractionOp contractOp,
if (srcBuff != srcBuffPairContOp)
return false;
+ bool oneConstantOffset = false;
for (size_t i = 0; i < indexVals.size(); i++) {
+
+ if (indexVals[i] == indexValsPairContOp[i])
+ continue;
+
auto v0 = getConstantIntValue(indexVals[i]);
auto v1 = getConstantIntValue(indexValsPairContOp[i]);
if (!v0 || !v1)
return false;
- if (*v1 == *v0)
- continue;
-
if ((*v1 - *v0) != nonUnitDimValue)
return false;
+
+ oneConstantOffset = true;
}
- return true;
+ return oneConstantOffset;
}
} // namespace x86
diff --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
index eabf15c0af303..0953ee042a24d 100644
--- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
@@ -412,6 +412,76 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x2xbf16>
+!vecB = vector<2x16xbf16>
+!vecC = vector<1x16xf32>
+!memrefA = memref<4x2xbf16>
+!memrefB = memref<2x32xbf16>
+!memrefC = memref<2x32xf32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC, %arg3: index) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : bf16
+ %32 = ub.poison : f32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %arg1[%arg3, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %5 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %4, %2
+ : !vecA, !vecB into !vecC
+
+ %6 = vector.transfer_read %arg1[%arg3, %c16], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %6, %3
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %5, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_bf16dp_flat_layout_offset_args_and_read_after_vc
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59] : vector<32xbf16>, vector<32xbf16>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63] : vector<32xbf16>, vector<32xbf16>
+// CHECK: x86.avx512.dot
+// CHECK: x86.avx512.dot
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x1x2xbf16>
!vecB = vector<1x2x16xbf16>
!vecC = vector<1x16xf32>
More information about the Mlir-commits
mailing list