[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to FMA or packed type dot-product (PR #168074)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Nov 20 06:08:32 PST 2025
================
@@ -0,0 +1,99 @@
+//===- VectorContractToFMA.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Implements outer product contraction as a sequence of broadcast and
+// FMA operations.
+//
+// For example - for F32 type:
+// ```
+// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
+// ```
+// to
+// ```
+// vector.broadcast %lhs to <16xf32>
+// vector.fma vector<16xf32>
+// ```
+struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ VectorType lhsTy = contractOp.getLhsType();
+ if (!lhsTy.getElementType().isF32())
----------------
adam-smnk wrote:
As silly as it looks:
```mlir
!vecA = vector<1x1x1xf32>
!vecB = vector<1x1x64xf32>
!vecC = vector<1x1x64xi32>
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @test(
%arg0: !vecA, %arg1: !vecB, %arg2: !vecC) -> !vecC
{
%0 = vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%arg0, %arg1, %arg2
: !vecA, !vecB into !vecC
return %0 : !vecC
}
```
is valid IR that this rewrite can't handle, so it's best to reject it.
https://github.com/llvm/llvm-project/pull/168074
More information about the Mlir-commits
mailing list