[Mlir-commits] [mlir] [draft] publicize vec contract lowering strategy (PR #196642)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 8 13:47:43 PDT 2026
https://github.com/efric created https://github.com/llvm/llvm-project/pull/196642
None
>From 50fe986c5e677d8ddac90b334a9bb1003e73c971 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Mon, 27 Apr 2026 21:35:05 -0700
Subject: [PATCH 1/4] experimental: expose vector contract lowerings as
multiple composable options
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
add more tests and the parallel-arith lowering entry point
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
Keep contract lowering filter patch focused
Reuse filter constraint naming for contract lowering policy
Rename vector contract default filter
---
.../Vector/Transforms/LoweringPatterns.h | 31 ++++
.../Vector/Transforms/LowerVectorContract.cpp | 154 ++++++++++++------
.../vector-contract-composable-lowering.mlir | 118 ++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 68 ++++++++
4 files changed, 323 insertions(+), 48 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..1c98b364d7e0f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -12,6 +12,8 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include <functional>
+
namespace mlir {
class RewritePatternSet;
@@ -47,6 +49,35 @@ namespace vector {
/// [ContractionOpToOuterProductOpLowering]
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
+
+/// A `FilterConstraintType` lets clients compose multiple lowering strategies
+/// by benefit. Returning failure means this strategy silently declines the op
+/// without consuming it or diagnosing invalid IR; lower-benefit strategies may
+/// still match the same op.
+using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
+
+LogicalResult defaultFilter(ContractionOp op);
+
+void populateVectorContractToDotPatterns(
+ RewritePatternSet &patterns,
+ FilterConstraintType filter = defaultFilter,
+ PatternBenefit benefit = 1);
+
+void populateVectorContractToOuterProductPatterns(
+ RewritePatternSet &patterns,
+ FilterConstraintType filter = defaultFilter,
+ PatternBenefit benefit = 1);
+
+void populateVectorContractToParallelArithPatterns(
+ RewritePatternSet &patterns,
+ FilterConstraintType filter = defaultFilter,
+ PatternBenefit benefit = 1);
+
+void populateVectorContractGenericLoweringPatterns(
+ RewritePatternSet &patterns,
+ FilterConstraintType filter = defaultFilter,
+ PatternBenefit benefit = 1);
+
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption,
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index eaf7bb8109514..9ae75fc5423f8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -223,13 +223,6 @@ class ContractionOpToOuterProductOpLowering
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
ContractionOpToOuterProductOpLowering(
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
@@ -271,19 +264,13 @@ class ContractionOpToDotLowering
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
-
ContractionOpToDotLowering(
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
- const FilterConstraintType &constraint = defaultFilter)
+ FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering),
+ filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -309,33 +296,34 @@ class ContractionOpToDotLowering
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering
+class ContractionOpGenericLowering
: public MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
-
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
- ContractionOpLowering(
- vector::VectorContractLowering vectorContractLoweringOption,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
+ ContractionOpGenericLowering(
+ MLIRContext *context,
+ FilterConstraintType constraint = defaultFilter,
+ PatternBenefit benefit = 1)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorContractLoweringOption(vectorContractLoweringOption),
filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
PatternRewriter &rewriter) const override;
+protected:
+ LogicalResult
+ matchSupportedGenericContraction(vector::ContractionOp op,
+ PatternRewriter &rewriter) const;
+
+ FailureOr<Value> lowerGenericContraction(PatternRewriter &rewriter,
+ vector::ContractionOp op,
+ MaskingOpInterface maskOp) const;
+
private:
- /// Options to control the vector patterns.
- vector::VectorContractLowering vectorContractLoweringOption;
FilterConstraintType filter;
+
// Lower one parallel dimension.
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
vector::ContractionOp op, int64_t lhsIndex,
@@ -345,6 +333,24 @@ class ContractionOpLowering
vector::ContractionOp op, Value mask) const;
};
+class ContractionOpLowering : public ContractionOpGenericLowering {
+public:
+ ContractionOpLowering(
+ vector::VectorContractLowering vectorContractLoweringOption,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
+ : ContractionOpGenericLowering(context, std::move(constraint), benefit),
+ vectorContractLoweringOption(vectorContractLoweringOption) {}
+
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// Options to control the vector patterns.
+ vector::VectorContractLowering vectorContractLoweringOption;
+};
+
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
@@ -744,17 +750,14 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
struct ContractOpToElementwise
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- using FilterConstraintType =
- std::function<LogicalResult(vector::ContractionOp op)>;
- static LogicalResult defaultFilter(vector::ContractionOp op) {
- return success();
- }
+
ContractOpToElementwise(
vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 1,
- const FilterConstraintType &constraint = defaultFilter)
+ FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
- vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
+ vectorContractLowering(vectorContractLowering),
+ filter(std::move(constraint)) {}
FailureOr<Value>
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -880,9 +883,8 @@ struct ContractOpToElementwise
// TODO: break down into transpose/reshape/cast ops
// when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
-FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
- vector::ContractionOp op, MaskingOpInterface maskOp,
- PatternRewriter &rewriter) const {
+LogicalResult ContractionOpGenericLowering::matchSupportedGenericContraction(
+ vector::ContractionOp op, PatternRewriter &rewriter) const {
if (failed(filter(op)))
return failure();
@@ -894,10 +896,26 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// TODO: the code below assumes the default contraction, make sure it supports
// other kinds before enabling this lowering.
- if (op.getKind() != vector::CombiningKind::ADD) {
+ if (op.getKind() != vector::CombiningKind::ADD)
return rewriter.notifyMatchFailure(
op, "contractions other than 'add' not supported");
- }
+ return success();
+}
+
+FailureOr<Value> ContractionOpGenericLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
+ if (failed(matchSupportedGenericContraction(op, rewriter)))
+ return failure();
+
+ return lowerGenericContraction(rewriter, op, maskOp);
+}
+
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
+ if (failed(matchSupportedGenericContraction(op, rewriter)))
+ return failure();
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
@@ -920,6 +938,12 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
if (!failed(newVal4))
return newVal4;
+ return lowerGenericContraction(rewriter, op, maskOp);
+}
+
+FailureOr<Value> ContractionOpGenericLowering::lowerGenericContraction(
+ PatternRewriter &rewriter, vector::ContractionOp op,
+ MaskingOpInterface maskOp) const {
// Vector mask setup.
Value mask;
@@ -982,11 +1006,9 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
-FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
- vector::ContractionOp op,
- int64_t lhsIndex,
- int64_t rhsIndex,
- Value mask) const {
+FailureOr<Value> ContractionOpGenericLowering::lowerParallel(
+ PatternRewriter &rewriter, vector::ContractionOp op, int64_t lhsIndex,
+ int64_t rhsIndex, Value mask) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = cast<VectorType>(op.getResultType());
@@ -1069,7 +1091,7 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
}
// Lower one reduction dimension.
-FailureOr<Value> ContractionOpLowering::lowerReduction(
+FailureOr<Value> ContractionOpGenericLowering::lowerReduction(
PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
@@ -1229,6 +1251,42 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
} // namespace
+LogicalResult
+mlir::vector::defaultFilter(ContractionOp) {
+ return success();
+}
+
+void mlir::vector::populateVectorContractToDotPatterns(
+ RewritePatternSet &patterns, FilterConstraintType filter,
+ PatternBenefit benefit) {
+ patterns.add<ContractionOpToDotLowering>(vector::VectorContractLowering::Dot,
+ patterns.getContext(), benefit,
+ std::move(filter));
+}
+
+void mlir::vector::populateVectorContractToOuterProductPatterns(
+ RewritePatternSet &patterns, FilterConstraintType filter,
+ PatternBenefit benefit) {
+ patterns.add<ContractionOpToOuterProductOpLowering>(
+ vector::VectorContractLowering::OuterProduct, patterns.getContext(),
+ benefit, std::move(filter));
+}
+
+void mlir::vector::populateVectorContractToParallelArithPatterns(
+ RewritePatternSet &patterns, FilterConstraintType filter,
+ PatternBenefit benefit) {
+ patterns.add<ContractOpToElementwise>(
+ vector::VectorContractLowering::ParallelArith, patterns.getContext(),
+ benefit, std::move(filter));
+}
+
+void mlir::vector::populateVectorContractGenericLoweringPatterns(
+ RewritePatternSet &patterns, FilterConstraintType filter,
+ PatternBenefit benefit) {
+ patterns.add<ContractionOpGenericLowering>(patterns.getContext(),
+ std::move(filter), benefit);
+}
+
void mlir::vector::populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
diff --git a/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir b/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir
new file mode 100644
index 0000000000000..2f755afb92951
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=composed" --split-input-file | FileCheck %s --check-prefix=COMPOSED
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=generic" --split-input-file | FileCheck %s --check-prefix=GENERIC
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=parallel-arith" --split-input-file | FileCheck %s --check-prefix=PARALLEL_ACCEPT
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=parallel-arith-reject" --split-input-file | FileCheck %s --check-prefix=PARALLEL_REJECT
+
+#matmat_accesses = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// COMPOSED-LABEL: func @dot_accept
+// COMPOSED-NOT: vector.outerproduct
+// COMPOSED: vector.reduction <add>
+func.func @dot_accept(%A: vector<2x4xf32>,
+ %B: vector<4x3xf32>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = vector.contract #matmat_trait %A, %B, %C
+ : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+// COMPOSED-LABEL: func @dot_reject_to_outerproduct
+// COMPOSED: vector.outerproduct
+func.func @dot_reject_to_outerproduct(%A: vector<2x4xf32>,
+ %B: vector<4x3xf32>,
+ %C: vector<2x3xf32>)
+ -> vector<2x3xf32> {
+ %0 = vector.contract #matmat_trait %A, %B, %C
+ : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+ return %0 : vector<2x3xf32>
+}
+
+#batch_matmul_accesses = [
+ affine_map<(b, m, n, k) -> (b, m, k)>,
+ affine_map<(b, m, n, k) -> (b, k, n)>,
+ affine_map<(b, m, n, k) -> (b, m, n)>
+]
+#batch_matmul_trait = {
+ indexing_maps = #batch_matmul_accesses,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+// COMPOSED-LABEL: func @dot_structural_failure_to_generic
+// COMPOSED-NOT: vector.outerproduct
+// COMPOSED: vector.extract
+// COMPOSED: vector.reduction <add>
+func.func @dot_structural_failure_to_generic(%A: vector<2x2x4xf32>,
+ %B: vector<2x4x3xf32>,
+ %C: vector<2x2x3xf32>)
+ -> vector<2x2x3xf32> {
+ %0 = vector.contract #batch_matmul_trait %A, %B, %C
+ : vector<2x2x4xf32>, vector<2x4x3xf32> into vector<2x2x3xf32>
+ return %0 : vector<2x2x3xf32>
+}
+
+// -----
+
+#dotp_accesses = [
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> (i)>,
+ affine_map<(i) -> ()>
+]
+#dotp_add_trait = {
+ indexing_maps = #dotp_accesses,
+ iterator_types = ["reduction"]
+}
+#dotp_mul_trait = {
+ indexing_maps = #dotp_accesses,
+ iterator_types = ["reduction"],
+ kind = #vector.kind<mul>
+}
+
+// GENERIC-LABEL: func @generic_add
+// GENERIC: arith.mulf
+// GENERIC: vector.reduction <add>
+func.func @generic_add(%A: vector<4xf32>, %B: vector<4xf32>,
+ %C: f32) -> f32 {
+ %0 = vector.contract #dotp_add_trait %A, %B, %C
+ : vector<4xf32>, vector<4xf32> into f32
+ return %0 : f32
+}
+
+// GENERIC-LABEL: func @generic_non_add
+// GENERIC: vector.contract
+// GENERIC-SAME: kind = #vector.kind<mul>
+func.func @generic_non_add(%A: vector<4xf32>, %B: vector<4xf32>,
+ %C: f32) -> f32 {
+ %0 = vector.contract #dotp_mul_trait %A, %B, %C
+ : vector<4xf32>, vector<4xf32> into f32
+ return %0 : f32
+}
+
+// -----
+
+// PARALLEL_ACCEPT-LABEL: func @parallel_arith
+// PARALLEL_ACCEPT-NOT: vector.contract
+// PARALLEL_ACCEPT: vector.fma
+// PARALLEL_REJECT-LABEL: func @parallel_arith
+// PARALLEL_REJECT: vector.contract
+func.func @parallel_arith(
+ %A: vector<1x1x4xf32>, %B: vector<1x1x4xf32>,
+ %C: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.contract {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+ affine_map<(d0, d1, d2) -> (d0)>
+ ],
+ iterator_types = ["parallel", "reduction", "reduction"],
+ kind = #vector.kind<add>
+ } %A, %B, %C : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ff3520a286cc8..13f13b7fba60a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include <optional>
+#include <string>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -138,6 +139,71 @@ struct TestVectorContractionPrepareForMMTLowering
}
};
+static bool parentFunctionNameContains(vector::ContractionOp op,
+ StringRef substring) {
+ if (auto funcOp = op->getParentOfType<func::FuncOp>())
+ return funcOp.getName().contains(substring);
+ return false;
+}
+
+struct TestVectorContractLoweringComposition final
+ : public PassWrapper<TestVectorContractLoweringComposition,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorContractLoweringComposition)
+
+ TestVectorContractLoweringComposition() = default;
+ TestVectorContractLoweringComposition(
+ const TestVectorContractLoweringComposition &pass)
+ : PassWrapper(pass) {}
+
+ StringRef getArgument() const final {
+ return "test-vector-contract-lowering-composition";
+ }
+
+ StringRef getDescription() const final {
+ return "Test composable vector.contract lowering pattern population.";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ if (mode == "composed") {
+ populateVectorContractToDotPatterns(
+ patterns,
+ [](vector::ContractionOp op) {
+ return success(!parentFunctionNameContains(op, "dot_reject"));
+ },
+ PatternBenefit(3));
+ populateVectorContractToOuterProductPatterns(
+ patterns, defaultFilter, PatternBenefit(2));
+ populateVectorContractGenericLoweringPatterns(
+ patterns, defaultFilter, PatternBenefit(1));
+ } else if (mode == "generic") {
+ populateVectorContractGenericLoweringPatterns(patterns);
+ } else if (mode == "parallel-arith") {
+ populateVectorContractToParallelArithPatterns(patterns);
+ } else if (mode == "parallel-arith-reject") {
+ populateVectorContractToParallelArithPatterns(
+ patterns, [](vector::ContractionOp) { return failure(); });
+ } else {
+ getOperation().emitError()
+ << "unknown contract lowering test mode: " << mode;
+ return signalPassFailure();
+ }
+
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+
+ Option<std::string> mode{
+ *this, "mode",
+ llvm::cl::desc("Contract lowering composition mode to test"),
+ llvm::cl::init("composed")}; // Default must be a mode handled above.
+};
+
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns,
OperationPass<func::FuncOp>> {
@@ -1053,6 +1119,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorContractionPrepareForMMTLowering>();
+ PassRegistration<TestVectorContractLoweringComposition>();
+
PassRegistration<TestVectorUnrollingPatterns>();
PassRegistration<TestVectorTransferUnrollingPatterns>();
>From 33d30f469b3e1b5b94980463bd8d579a52b7e697 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:41:29 -0700
Subject: [PATCH 2/4] nit
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 4 ----
1 file changed, 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1c98b364d7e0f..5bcab25bf0bbc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -50,10 +50,6 @@ namespace vector {
/// Progressively lower a `vector.contract` with row-major matmul semantics to
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
-/// A `FilterConstraintType` lets clients compose multiple lowering strategies
-/// by benefit. Returning failure means this strategy silently declines the op
-/// without consuming it or diagnosing invalid IR; lower-benefit strategies may
-/// still match the same op.
using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
LogicalResult defaultFilter(ContractionOp op);
>From 129021fa523549c7c1388bcdd6b224945f873c20 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:43:06 -0700
Subject: [PATCH 3/4] format
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
.../Vector/Transforms/LoweringPatterns.h | 12 ++++--------
.../Vector/Transforms/LowerVectorContract.cpp | 19 +++++++------------
2 files changed, 11 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 5bcab25bf0bbc..b3aa86ec7eeee 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -55,23 +55,19 @@ using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
LogicalResult defaultFilter(ContractionOp op);
void populateVectorContractToDotPatterns(
- RewritePatternSet &patterns,
- FilterConstraintType filter = defaultFilter,
+ RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
PatternBenefit benefit = 1);
void populateVectorContractToOuterProductPatterns(
- RewritePatternSet &patterns,
- FilterConstraintType filter = defaultFilter,
+ RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
PatternBenefit benefit = 1);
void populateVectorContractToParallelArithPatterns(
- RewritePatternSet &patterns,
- FilterConstraintType filter = defaultFilter,
+ RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
PatternBenefit benefit = 1);
void populateVectorContractGenericLoweringPatterns(
- RewritePatternSet &patterns,
- FilterConstraintType filter = defaultFilter,
+ RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
PatternBenefit benefit = 1);
void populateVectorContractLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 9ae75fc5423f8..7ed4731f2f2c3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -301,10 +301,9 @@ class ContractionOpGenericLowering
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- ContractionOpGenericLowering(
- MLIRContext *context,
- FilterConstraintType constraint = defaultFilter,
- PatternBenefit benefit = 1)
+ ContractionOpGenericLowering(MLIRContext *context,
+ FilterConstraintType constraint = defaultFilter,
+ PatternBenefit benefit = 1)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
filter(std::move(constraint)) {}
@@ -751,10 +750,9 @@ struct ContractOpToElementwise
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
- ContractOpToElementwise(
- vector::VectorContractLowering vectorContractLowering,
- MLIRContext *context, PatternBenefit benefit = 1,
- FilterConstraintType constraint = defaultFilter)
+ ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering,
+ MLIRContext *context, PatternBenefit benefit = 1,
+ FilterConstraintType constraint = defaultFilter)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorContractLowering(vectorContractLowering),
filter(std::move(constraint)) {}
@@ -1251,10 +1249,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
} // namespace
-LogicalResult
-mlir::vector::defaultFilter(ContractionOp) {
- return success();
-}
+LogicalResult mlir::vector::defaultFilter(ContractionOp) { return success(); }
void mlir::vector::populateVectorContractToDotPatterns(
RewritePatternSet &patterns, FilterConstraintType filter,
>From cb749c8b9b0eac6e54fd6d68c802115f2f77236c Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:44:08 -0700
Subject: [PATCH 4/4] format
Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 13f13b7fba60a..9fcc6fe06548c 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,10 +178,10 @@ struct TestVectorContractLoweringComposition final
return success(!parentFunctionNameContains(op, "dot_reject"));
},
PatternBenefit(3));
- populateVectorContractToOuterProductPatterns(
- patterns, defaultFilter, PatternBenefit(2));
- populateVectorContractGenericLoweringPatterns(
- patterns, defaultFilter, PatternBenefit(1));
+ populateVectorContractToOuterProductPatterns(patterns, defaultFilter,
+ PatternBenefit(2));
+ populateVectorContractGenericLoweringPatterns(patterns, defaultFilter,
+ PatternBenefit(1));
} else if (mode == "generic") {
populateVectorContractGenericLoweringPatterns(patterns);
} else if (mode == "parallel-arith") {
More information about the Mlir-commits
mailing list