[Mlir-commits] [mlir] [draft] publicize vec contract lowering strategy (PR #196642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 8 13:47:43 PDT 2026


https://github.com/efric created https://github.com/llvm/llvm-project/pull/196642

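Expose the individual vector.contract lowering strategies (dot, outer product, parallel-arith, and generic) as separate populate functions, each taking a filter and a pattern benefit, so clients can compose them.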

>From 50fe986c5e677d8ddac90b334a9bb1003e73c971 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Mon, 27 Apr 2026 21:35:05 -0700
Subject: [PATCH 1/4] Experimentally expose vector contract lowerings as
 multiple composable options

Signed-off-by: Eric Feng <Eric.Feng at amd.com>

Add more tests and the parallel-arith populate entry point

Signed-off-by: Eric Feng <Eric.Feng at amd.com>

Keep contract lowering filter patch focused

Reuse filter constraint naming for contract lowering policy

Rename vector contract default filter
---
 .../Vector/Transforms/LoweringPatterns.h      |  31 ++++
 .../Vector/Transforms/LowerVectorContract.cpp | 154 ++++++++++++------
 .../vector-contract-composable-lowering.mlir  | 118 ++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  68 ++++++++
 4 files changed, 323 insertions(+), 48 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..1c98b364d7e0f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <functional>
+
 namespace mlir {
 class RewritePatternSet;
 
@@ -47,6 +49,35 @@ namespace vector {
 /// [ContractionOpToOuterProductOpLowering]
 /// Progressively lower a `vector.contract` with row-major matmul semantics to
 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
+
+/// A `FilterConstraintType` lets clients compose multiple lowering strategies
+/// by benefit. Returning failure means this strategy silently declines the op
+/// without consuming it or diagnosing invalid IR; lower-benefit strategies may
+/// still match the same op.
+using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
+
+LogicalResult defaultFilter(ContractionOp op);
+
+void populateVectorContractToDotPatterns(
+    RewritePatternSet &patterns,
+    FilterConstraintType filter = defaultFilter,
+    PatternBenefit benefit = 1);
+
+void populateVectorContractToOuterProductPatterns(
+    RewritePatternSet &patterns,
+    FilterConstraintType filter = defaultFilter,
+    PatternBenefit benefit = 1);
+
+void populateVectorContractToParallelArithPatterns(
+    RewritePatternSet &patterns,
+    FilterConstraintType filter = defaultFilter,
+    PatternBenefit benefit = 1);
+
+void populateVectorContractGenericLoweringPatterns(
+    RewritePatternSet &patterns,
+    FilterConstraintType filter = defaultFilter,
+    PatternBenefit benefit = 1);
+
 void populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
     VectorContractLowering vectorContractLoweringOption,
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index eaf7bb8109514..9ae75fc5423f8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -223,13 +223,6 @@ class ContractionOpToOuterProductOpLowering
 public:
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
   ContractionOpToOuterProductOpLowering(
       vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
@@ -271,19 +264,13 @@ class ContractionOpToDotLowering
 public:
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
-
   ContractionOpToDotLowering(
       vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
-      const FilterConstraintType &constraint = defaultFilter)
+      FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering),
+        filter(std::move(constraint)) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -309,33 +296,34 @@ class ContractionOpToDotLowering
 ///
 /// This only kicks in when either VectorTransformsOptions is set
 /// to Dot or when other contraction patterns fail.
-class ContractionOpLowering
+class ContractionOpGenericLowering
     : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
 
-  ContractionOpLowering(
-      vector::VectorContractLowering vectorContractLoweringOption,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      FilterConstraintType constraint = defaultFilter)
+  ContractionOpGenericLowering(
+      MLIRContext *context,
+      FilterConstraintType constraint = defaultFilter,
+      PatternBenefit benefit = 1)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorContractLoweringOption(vectorContractLoweringOption),
         filter(std::move(constraint)) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                             PatternRewriter &rewriter) const override;
 
+protected:
+  LogicalResult
+  matchSupportedGenericContraction(vector::ContractionOp op,
+                                   PatternRewriter &rewriter) const;
+
+  FailureOr<Value> lowerGenericContraction(PatternRewriter &rewriter,
+                                           vector::ContractionOp op,
+                                           MaskingOpInterface maskOp) const;
+
 private:
-  /// Options to control the vector patterns.
-  vector::VectorContractLowering vectorContractLoweringOption;
   FilterConstraintType filter;
+
   // Lower one parallel dimension.
   FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                  vector::ContractionOp op, int64_t lhsIndex,
@@ -345,6 +333,24 @@ class ContractionOpLowering
                                   vector::ContractionOp op, Value mask) const;
 };
 
+class ContractionOpLowering : public ContractionOpGenericLowering {
+public:
+  ContractionOpLowering(
+      vector::VectorContractLowering vectorContractLoweringOption,
+      MLIRContext *context, PatternBenefit benefit = 1,
+      FilterConstraintType constraint = defaultFilter)
+      : ContractionOpGenericLowering(context, std::move(constraint), benefit),
+        vectorContractLoweringOption(vectorContractLoweringOption) {}
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorContractLowering vectorContractLoweringOption;
+};
+
 /// Generate a vector implementation for matmat, matvec and tmatvec.
 /// This unrolls outer-products along the reduction dimension.
 struct UnrolledOuterProductGenerator
@@ -744,17 +750,14 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
 struct ContractOpToElementwise
     : public MaskableOpRewritePattern<vector::ContractionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
-  using FilterConstraintType =
-      std::function<LogicalResult(vector::ContractionOp op)>;
-  static LogicalResult defaultFilter(vector::ContractionOp op) {
-    return success();
-  }
+
   ContractOpToElementwise(
       vector::VectorContractLowering vectorContractLowering,
       MLIRContext *context, PatternBenefit benefit = 1,
-      const FilterConstraintType &constraint = defaultFilter)
+      FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
-        vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
+        vectorContractLowering(vectorContractLowering),
+        filter(std::move(constraint)) {}
 
   FailureOr<Value>
   matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -880,9 +883,8 @@ struct ContractOpToElementwise
 // TODO: break down into transpose/reshape/cast ops
 //               when they become available to avoid code dup
 // TODO: investigate lowering order impact on performance
-FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
-    vector::ContractionOp op, MaskingOpInterface maskOp,
-    PatternRewriter &rewriter) const {
+LogicalResult ContractionOpGenericLowering::matchSupportedGenericContraction(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
   if (failed(filter(op)))
     return failure();
 
@@ -894,10 +896,26 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
 
   // TODO: the code below assumes the default contraction, make sure it supports
   // other kinds before enabling this lowering.
-  if (op.getKind() != vector::CombiningKind::ADD) {
+  if (op.getKind() != vector::CombiningKind::ADD)
     return rewriter.notifyMatchFailure(
         op, "contractions other than 'add' not supported");
-  }
+  return success();
+}
+
+FailureOr<Value> ContractionOpGenericLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
+  if (failed(matchSupportedGenericContraction(op, rewriter)))
+    return failure();
+
+  return lowerGenericContraction(rewriter, op, maskOp);
+}
+
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
+  if (failed(matchSupportedGenericContraction(op, rewriter)))
+    return failure();
 
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
@@ -920,6 +938,12 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
   if (!failed(newVal4))
     return newVal4;
 
+  return lowerGenericContraction(rewriter, op, maskOp);
+}
+
+FailureOr<Value> ContractionOpGenericLowering::lowerGenericContraction(
+    PatternRewriter &rewriter, vector::ContractionOp op,
+    MaskingOpInterface maskOp) const {
   // Vector mask setup.
 
   Value mask;
@@ -982,11 +1006,9 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
 // Lower one parallel dimension.
 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
-FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
-                                                      vector::ContractionOp op,
-                                                      int64_t lhsIndex,
-                                                      int64_t rhsIndex,
-                                                      Value mask) const {
+FailureOr<Value> ContractionOpGenericLowering::lowerParallel(
+    PatternRewriter &rewriter, vector::ContractionOp op, int64_t lhsIndex,
+    int64_t rhsIndex, Value mask) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   VectorType resType = cast<VectorType>(op.getResultType());
@@ -1069,7 +1091,7 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
 }
 
 // Lower one reduction dimension.
-FailureOr<Value> ContractionOpLowering::lowerReduction(
+FailureOr<Value> ContractionOpGenericLowering::lowerReduction(
     PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
   auto loc = op.getLoc();
   VectorType lhsType = op.getLhsType();
@@ -1229,6 +1251,42 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 
 } // namespace
 
+LogicalResult
+mlir::vector::defaultFilter(ContractionOp) {
+  return success();
+}
+
+void mlir::vector::populateVectorContractToDotPatterns(
+    RewritePatternSet &patterns, FilterConstraintType filter,
+    PatternBenefit benefit) {
+  patterns.add<ContractionOpToDotLowering>(vector::VectorContractLowering::Dot,
+                                           patterns.getContext(), benefit,
+                                           std::move(filter));
+}
+
+void mlir::vector::populateVectorContractToOuterProductPatterns(
+    RewritePatternSet &patterns, FilterConstraintType filter,
+    PatternBenefit benefit) {
+  patterns.add<ContractionOpToOuterProductOpLowering>(
+      vector::VectorContractLowering::OuterProduct, patterns.getContext(),
+      benefit, std::move(filter));
+}
+
+void mlir::vector::populateVectorContractToParallelArithPatterns(
+    RewritePatternSet &patterns, FilterConstraintType filter,
+    PatternBenefit benefit) {
+  patterns.add<ContractOpToElementwise>(
+      vector::VectorContractLowering::ParallelArith, patterns.getContext(),
+      benefit, std::move(filter));
+}
+
+void mlir::vector::populateVectorContractGenericLoweringPatterns(
+    RewritePatternSet &patterns, FilterConstraintType filter,
+    PatternBenefit benefit) {
+  patterns.add<ContractionOpGenericLowering>(patterns.getContext(),
+                                             std::move(filter), benefit);
+}
+
 void mlir::vector::populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
     VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
diff --git a/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir b/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir
new file mode 100644
index 0000000000000..2f755afb92951
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-composable-lowering.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=composed" --split-input-file | FileCheck %s --check-prefix=COMPOSED
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=generic" --split-input-file | FileCheck %s --check-prefix=GENERIC
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=parallel-arith" --split-input-file | FileCheck %s --check-prefix=PARALLEL_ACCEPT
+// RUN: mlir-opt %s --test-vector-contract-lowering-composition="mode=parallel-arith-reject" --split-input-file | FileCheck %s --check-prefix=PARALLEL_REJECT
+
+#matmat_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// COMPOSED-LABEL: func @dot_accept
+// COMPOSED-NOT: vector.outerproduct
+// COMPOSED: vector.reduction <add>
+func.func @dot_accept(%A: vector<2x4xf32>,
+                      %B: vector<4x3xf32>,
+                      %C: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %A, %B, %C
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// COMPOSED-LABEL: func @dot_reject_to_outerproduct
+// COMPOSED: vector.outerproduct
+func.func @dot_reject_to_outerproduct(%A: vector<2x4xf32>,
+                                      %B: vector<4x3xf32>,
+                                      %C: vector<2x3xf32>)
+                                      -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %A, %B, %C
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+#batch_matmul_accesses = [
+  affine_map<(b, m, n, k) -> (b, m, k)>,
+  affine_map<(b, m, n, k) -> (b, k, n)>,
+  affine_map<(b, m, n, k) -> (b, m, n)>
+]
+#batch_matmul_trait = {
+  indexing_maps = #batch_matmul_accesses,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+// COMPOSED-LABEL: func @dot_structural_failure_to_generic
+// COMPOSED-NOT: vector.outerproduct
+// COMPOSED: vector.extract
+// COMPOSED: vector.reduction <add>
+func.func @dot_structural_failure_to_generic(%A: vector<2x2x4xf32>,
+                                             %B: vector<2x4x3xf32>,
+                                             %C: vector<2x2x3xf32>)
+                                             -> vector<2x2x3xf32> {
+  %0 = vector.contract #batch_matmul_trait %A, %B, %C
+    : vector<2x2x4xf32>, vector<2x4x3xf32> into vector<2x2x3xf32>
+  return %0 : vector<2x2x3xf32>
+}
+
+// -----
+
+#dotp_accesses = [
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> (i)>,
+  affine_map<(i) -> ()>
+]
+#dotp_add_trait = {
+  indexing_maps = #dotp_accesses,
+  iterator_types = ["reduction"]
+}
+#dotp_mul_trait = {
+  indexing_maps = #dotp_accesses,
+  iterator_types = ["reduction"],
+  kind = #vector.kind<mul>
+}
+
+// GENERIC-LABEL: func @generic_add
+// GENERIC: arith.mulf
+// GENERIC: vector.reduction <add>
+func.func @generic_add(%A: vector<4xf32>, %B: vector<4xf32>,
+                       %C: f32) -> f32 {
+  %0 = vector.contract #dotp_add_trait %A, %B, %C
+    : vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
+
+// GENERIC-LABEL: func @generic_non_add
+// GENERIC: vector.contract
+// GENERIC-SAME: kind = #vector.kind<mul>
+func.func @generic_non_add(%A: vector<4xf32>, %B: vector<4xf32>,
+                           %C: f32) -> f32 {
+  %0 = vector.contract #dotp_mul_trait %A, %B, %C
+    : vector<4xf32>, vector<4xf32> into f32
+  return %0 : f32
+}
+
+// -----
+
+// PARALLEL_ACCEPT-LABEL: func @parallel_arith
+// PARALLEL_ACCEPT-NOT: vector.contract
+// PARALLEL_ACCEPT: vector.fma
+// PARALLEL_REJECT-LABEL: func @parallel_arith
+// PARALLEL_REJECT: vector.contract
+func.func @parallel_arith(
+    %A: vector<1x1x4xf32>, %B: vector<1x1x4xf32>,
+    %C: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+      affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+      affine_map<(d0, d1, d2) -> (d0)>
+    ],
+    iterator_types = ["parallel", "reduction", "reduction"],
+    kind = #vector.kind<add>
+  } %A, %B, %C : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index ff3520a286cc8..13f13b7fba60a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <optional>
+#include <string>
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -138,6 +139,71 @@ struct TestVectorContractionPrepareForMMTLowering
   }
 };
 
+static bool parentFunctionNameContains(vector::ContractionOp op,
+                                       StringRef substring) {
+  if (auto funcOp = op->getParentOfType<func::FuncOp>())
+    return funcOp.getName().contains(substring);
+  return false;
+}
+
+struct TestVectorContractLoweringComposition final
+    : public PassWrapper<TestVectorContractLoweringComposition,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContractLoweringComposition)
+
+  TestVectorContractLoweringComposition() = default;
+  TestVectorContractLoweringComposition(
+      const TestVectorContractLoweringComposition &pass)
+      : PassWrapper(pass) {}
+
+  StringRef getArgument() const final {
+    return "test-vector-contract-lowering-composition";
+  }
+
+  StringRef getDescription() const final {
+    return "Test composable vector.contract lowering pattern population.";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    if (mode == "composed") {
+      populateVectorContractToDotPatterns(
+          patterns,
+          [](vector::ContractionOp op) {
+            return success(!parentFunctionNameContains(op, "dot_reject"));
+          },
+          PatternBenefit(3));
+      populateVectorContractToOuterProductPatterns(
+          patterns, defaultFilter, PatternBenefit(2));
+      populateVectorContractGenericLoweringPatterns(
+          patterns, defaultFilter, PatternBenefit(1));
+    } else if (mode == "generic") {
+      populateVectorContractGenericLoweringPatterns(patterns);
+    } else if (mode == "parallel-arith") {
+      populateVectorContractToParallelArithPatterns(patterns);
+    } else if (mode == "parallel-arith-reject") {
+      populateVectorContractToParallelArithPatterns(
+          patterns, [](vector::ContractionOp) { return failure(); });
+    } else {
+      getOperation().emitError()
+          << "unknown contract lowering test mode: " << mode;
+      return signalPassFailure();
+    }
+
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+
+  Option<std::string> mode{
+      *this, "mode",
+      llvm::cl::desc("Contract lowering composition mode to test"),
+      llvm::cl::init("composed")}; // Default must be a mode handled above.
+};
+
 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1053,6 +1119,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorContractionPrepareForMMTLowering>();
 
+  PassRegistration<TestVectorContractLoweringComposition>();
+
   PassRegistration<TestVectorUnrollingPatterns>();
 
   PassRegistration<TestVectorTransferUnrollingPatterns>();

>From 33d30f469b3e1b5b94980463bd8d579a52b7e697 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:41:29 -0700
Subject: [PATCH 2/4] Drop the FilterConstraintType doc comment

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 1c98b364d7e0f..5bcab25bf0bbc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -50,10 +50,6 @@ namespace vector {
 /// Progressively lower a `vector.contract` with row-major matmul semantics to
 /// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
 
-/// A `FilterConstraintType` lets clients compose multiple lowering strategies
-/// by benefit. Returning failure means this strategy silently declines the op
-/// without consuming it or diagnosing invalid IR; lower-benefit strategies may
-/// still match the same op.
 using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
 
 LogicalResult defaultFilter(ContractionOp op);

>From 129021fa523549c7c1388bcdd6b224945f873c20 Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:43:06 -0700
Subject: [PATCH 3/4] Reformat LoweringPatterns.h and LowerVectorContract.cpp

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 .../Vector/Transforms/LoweringPatterns.h      | 12 ++++--------
 .../Vector/Transforms/LowerVectorContract.cpp | 19 +++++++------------
 2 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 5bcab25bf0bbc..b3aa86ec7eeee 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -55,23 +55,19 @@ using FilterConstraintType = std::function<LogicalResult(ContractionOp op)>;
 LogicalResult defaultFilter(ContractionOp op);
 
 void populateVectorContractToDotPatterns(
-    RewritePatternSet &patterns,
-    FilterConstraintType filter = defaultFilter,
+    RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
     PatternBenefit benefit = 1);
 
 void populateVectorContractToOuterProductPatterns(
-    RewritePatternSet &patterns,
-    FilterConstraintType filter = defaultFilter,
+    RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
     PatternBenefit benefit = 1);
 
 void populateVectorContractToParallelArithPatterns(
-    RewritePatternSet &patterns,
-    FilterConstraintType filter = defaultFilter,
+    RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
     PatternBenefit benefit = 1);
 
 void populateVectorContractGenericLoweringPatterns(
-    RewritePatternSet &patterns,
-    FilterConstraintType filter = defaultFilter,
+    RewritePatternSet &patterns, FilterConstraintType filter = defaultFilter,
     PatternBenefit benefit = 1);
 
 void populateVectorContractLoweringPatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 9ae75fc5423f8..7ed4731f2f2c3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -301,10 +301,9 @@ class ContractionOpGenericLowering
 public:
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  ContractionOpGenericLowering(
-      MLIRContext *context,
-      FilterConstraintType constraint = defaultFilter,
-      PatternBenefit benefit = 1)
+  ContractionOpGenericLowering(MLIRContext *context,
+                               FilterConstraintType constraint = defaultFilter,
+                               PatternBenefit benefit = 1)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         filter(std::move(constraint)) {}
 
@@ -751,10 +750,9 @@ struct ContractOpToElementwise
     : public MaskableOpRewritePattern<vector::ContractionOp> {
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  ContractOpToElementwise(
-      vector::VectorContractLowering vectorContractLowering,
-      MLIRContext *context, PatternBenefit benefit = 1,
-      FilterConstraintType constraint = defaultFilter)
+  ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering,
+                          MLIRContext *context, PatternBenefit benefit = 1,
+                          FilterConstraintType constraint = defaultFilter)
       : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorContractLowering(vectorContractLowering),
         filter(std::move(constraint)) {}
@@ -1251,10 +1249,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 
 } // namespace
 
-LogicalResult
-mlir::vector::defaultFilter(ContractionOp) {
-  return success();
-}
+LogicalResult mlir::vector::defaultFilter(ContractionOp) { return success(); }
 
 void mlir::vector::populateVectorContractToDotPatterns(
     RewritePatternSet &patterns, FilterConstraintType filter,

>From cb749c8b9b0eac6e54fd6d68c802115f2f77236c Mon Sep 17 00:00:00 2001
From: Eric Feng <Eric.Feng at amd.com>
Date: Fri, 8 May 2026 13:44:08 -0700
Subject: [PATCH 4/4] Reformat TestVectorTransforms.cpp

Signed-off-by: Eric Feng <Eric.Feng at amd.com>
---
 mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 13f13b7fba60a..9fcc6fe06548c 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,10 +178,10 @@ struct TestVectorContractLoweringComposition final
             return success(!parentFunctionNameContains(op, "dot_reject"));
           },
           PatternBenefit(3));
-      populateVectorContractToOuterProductPatterns(
-          patterns, defaultFilter, PatternBenefit(2));
-      populateVectorContractGenericLoweringPatterns(
-          patterns, defaultFilter, PatternBenefit(1));
+      populateVectorContractToOuterProductPatterns(patterns, defaultFilter,
+                                                   PatternBenefit(2));
+      populateVectorContractGenericLoweringPatterns(patterns, defaultFilter,
+                                                    PatternBenefit(1));
     } else if (mode == "generic") {
       populateVectorContractGenericLoweringPatterns(patterns);
     } else if (mode == "parallel-arith") {



More information about the Mlir-commits mailing list