[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)

Hugo Trachino llvmlistbot at llvm.org
Fri Jun 21 00:52:03 PDT 2024


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/93664

>From dab57eb7626b30cacacf8a63e4707457b6788f7e Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 20 May 2024 18:58:54 +0800
Subject: [PATCH 1/6] [mlir][vector] Add ElementwiseToOuterproduct

---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  4 +
 .../Vector/TransformOps/VectorTransformOps.td | 11 +++
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Vector/Transforms/VectorTransforms.cpp    | 75 +++++++++++++++++++
 .../test/Dialect/Vector/transform-vector.mlir | 38 ++++++++++
 5 files changed, 133 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa..ac55433fadb2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
 /// into vector contract for the backends with native support.
 void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns that fold elementwise ops on vectors into ops of
+/// the vector dialect.
+void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..e1da09fba73a7 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.elementwise_to_vector",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns that fold elementwise ops on vectors into ops of
+    the vector dialect.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.reduction_to_contract",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..6e13749a66415 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
   vector::populateFoldArithExtensionPatterns(patterns);
 }
 
+void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateElementwiseToVectorOpsPatterns(patterns);
+}
+
 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..d7ccfc4986068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
+/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
+/// into vector.outerproduct(A, B) such as :
+/// ```mlir
+///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
+///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
+///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///```
+/// Becomes :
+///```mlir
+///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///```
+/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+
+template <typename MulOpType>
+struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+  using OpRewritePattern<MulOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOpType mulOp,
+                                PatternRewriter &rewriter) const override {
+    auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
+    if (!VT)
+      return failure();
+    if (VT.getRank() != 2)
+      return failure();
+
+    auto canonicalize = [&](Value OperandA,
+                            Value OperandB) -> vector::OuterProductOp {
+      vector::TransposeOp transposedLhs =
+          dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+      if (!transposedLhs)
+        return vector::OuterProductOp();
+      // Fail unless this is a true 2-D matrix transpose.
+      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+      if (permutation[0] != 1 || permutation[1] != 0)
+        return vector::OuterProductOp();
+
+      // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
+      // generating shape_casts/broadcasts which do not belong in this pattern.
+      vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
+          transposedLhs.getVector().getDefiningOp());
+      if (!broadcastedLhs ||
+          !broadcastedLhs.computeBroadcastedUnitDims().empty())
+        return vector::OuterProductOp();
+      // Avoid broadcast f32 or vector<f32> -> ResType
+      auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
+      if (!srcVT || srcVT.getRank() != 1)
+        return vector::OuterProductOp();
+
+      vector::BroadcastOp broadcastedRhs =
+          dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
+      if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+        return vector::OuterProductOp();
+
+      return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
+          mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
+          Value(), vector::CombiningKind::ADD);
+    };
+    Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+    vector::OuterProductOp outerP = canonicalize(a, b);
+    // Handle commutativity, the transposed op is the outerproduct LHS.
+    outerP = outerP ? outerP : canonicalize(b, a);
+    return outerP ? success() : failure();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
                                          maxNumElementsToExtract, benefit);
 }
 
+void mlir::vector::populateElementwiseToVectorOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
+               ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 75b29e22b4d2c..c170486f6ce27 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @ewise_outerproduct
+//  CHECK-SAME:   %[[LHS:.*]]: vector<[4]xi32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+//       CHECK:     return %[[RES]] : vector<[4]x[4]xi32>
+func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+  %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+  return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+//  CHECK-SAME:   %[[LHS:.*]]: vector<16xf32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
+//       CHECK:     return %[[RES]] : vector<16x16xf32>
+func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
+  %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
+  %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
+  %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
+  return %mul: vector<16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.elementwise_to_vector
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 0f454b045f27c27c57601454854c8add7e147fb3 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 30 May 2024 17:30:03 +0800
Subject: [PATCH 2/6] Add support for different sizes rhs/lhs

---
 .../Vector/Transforms/VectorTransforms.cpp    | 23 +++++++++++--------
 .../test/Dialect/Vector/transform-vector.mlir | 13 +++++++++++
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d7ccfc4986068..a48101699c4f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1815,6 +1815,18 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 template <typename MulOpType>
 struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
   using OpRewritePattern<MulOpType>::OpRewritePattern;
+  // Helper function returning the source of the input broadcast if it matches requirements for an outerproduct pattern.
+  Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
+    // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
+    // shape_casts/broadcasts which does not belong in this pattern.
+    if (!broadcastOp.computeBroadcastedUnitDims().empty())
+      return Value();
+    // Avoid broadcast like f32 or vector<f32> -> ResType
+    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVT || srcVT.getRank() != 1)
+      return Value();
+    return broadcastOp.getSource();
+  }
 
   LogicalResult matchAndRewrite(MulOpType mulOp,
                                 PatternRewriter &rewriter) const override {
@@ -1835,21 +1847,14 @@ struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
       if (permutation[0] != 1 || permutation[1] != 0)
         return vector::OuterProductOp();
 
-      // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
-      // generating shape_casts/broadcasts which do not belong in this pattern.
       vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
           transposedLhs.getVector().getDefiningOp());
-      if (!broadcastedLhs ||
-          !broadcastedLhs.computeBroadcastedUnitDims().empty())
-        return vector::OuterProductOp();
-      // Avoid broadcast f32 or vector<f32> -> ResType
-      auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
-      if (!srcVT || srcVT.getRank() != 1)
+      if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs))
         return vector::OuterProductOp();
 
       vector::BroadcastOp broadcastedRhs =
           dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
-      if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+      if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs))
         return vector::OuterProductOp();
 
       return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index c170486f6ce27..783deb276f3cc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -121,6 +121,19 @@ func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<
   return %mul: vector<16x16xf32>
 }
 
+// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
+//  CHECK-SAME:   %[[LHS:.*]]: vector<8xf32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32>
+//       CHECK:     return %[[RES]] : vector<8x4xf32>
+func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> {
+  %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32>
+  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+  %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32>
+  %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32>
+  return %mul: vector<8x4xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op

>From 9889dc21ea2fcb951dc97b585f24b4831206cb29 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 30 May 2024 17:37:13 +0800
Subject: [PATCH 3/6] fix formatting

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a48101699c4f7..0bbdffeb5d9a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1815,7 +1815,8 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 template <typename MulOpType>
 struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
   using OpRewritePattern<MulOpType>::OpRewritePattern;
-  // Helper function returning the source of the input broadcast if it matches requirements for an outerproduct pattern.
+  // Helper function returning the source of the input broadcast if it matches
+  // requirements for an outerproduct pattern.
   Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
     // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
     // shape_casts/broadcasts which does not belong in this pattern.

>From de63fd6f2f54867e3280c76808f984e4ec2ca17e Mon Sep 17 00:00:00 2001
From: Hugo Trachino <hugo.trachino at huawei.com>
Date: Mon, 3 Jun 2024 09:23:19 +0100
Subject: [PATCH 4/6] Update
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Co-authored-by: Han-Chung Wang <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 0bbdffeb5d9a6..c7874d4506892 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1797,21 +1797,20 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 
 /// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
 /// into vector.outerproduct(A, B) such as :
-/// ```mlir
+///
 ///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
 ///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
 ///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
 ///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
-///```
+///
 /// Becomes :
-///```mlir
+///
 ///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
-///```
+///
 /// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
-
 template <typename MulOpType>
 struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
   using OpRewritePattern<MulOpType>::OpRewritePattern;

>From aa165d2715046e3eb903d87f0cda95b2e0facc2c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 3 Jun 2024 19:13:23 +0800
Subject: [PATCH 5/6] fix review comments

---
 .../Vector/Transforms/VectorTransforms.cpp    | 94 ++++++++++---------
 .../test/Dialect/Vector/transform-vector.mlir | 37 +++-----
 2 files changed, 63 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7874d4506892..827c789df7f50 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1795,9 +1795,9 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
-/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
-/// into vector.outerproduct(A, B) such as :
-///
+/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
+/// B)`.
+/// Example:
 ///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
 ///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
 ///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
@@ -1807,65 +1807,72 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 ///
 ///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
 ///
-/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
+/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
 /// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
 /// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
 /// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
 template <typename MulOpType>
-struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
   using OpRewritePattern<MulOpType>::OpRewritePattern;
-  // Helper function returning the source of the input broadcast if it matches
-  // requirements for an outerproduct pattern.
-  Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
+  // Returns whether a vector.broadcast matches requirements for an outerproduct
+  // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
+  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
     // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
     // shape_casts/broadcasts which does not belong in this pattern.
     if (!broadcastOp.computeBroadcastedUnitDims().empty())
-      return Value();
+      return false;
     // Avoid broadcast like f32 or vector<f32> -> ResType
-    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
-    if (!srcVT || srcVT.getRank() != 1)
-      return Value();
-    return broadcastOp.getSource();
+    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcType || srcType.getRank() == 2)
+      return false;
+    return true;
   }
 
   LogicalResult matchAndRewrite(MulOpType mulOp,
                                 PatternRewriter &rewriter) const override {
-    auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
-    if (!VT)
+    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+    if (!resType)
       return failure();
-    if (VT.getRank() != 2)
+    if (resType.getRank() != 2)
       return failure();
-
-    auto canonicalize = [&](Value OperandA,
-                            Value OperandB) -> vector::OuterProductOp {
+    /// If operandA can be written as tr(broadcast(A)) and operandB as
+    /// broadcast(B) where broadcasts are 1D-to-2D, create and return
+    /// vector.outerproduct(A, B). Returns failure() otherwise.
+    auto matchOuterProduct =
+        [&](Value operandA,
+            Value operandB) -> FailureOr<vector::OuterProductOp> {
       vector::TransposeOp transposedLhs =
-          dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+          dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp());
       if (!transposedLhs)
-        return vector::OuterProductOp();
+        return failure();
       // Fail unless this is a true 2-D matrix transpose.
       ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
-      if (permutation[0] != 1 || permutation[1] != 0)
-        return vector::OuterProductOp();
-
-      vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
-          transposedLhs.getVector().getDefiningOp());
-      if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs))
-        return vector::OuterProductOp();
-
-      vector::BroadcastOp broadcastedRhs =
-          dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
-      if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs))
-        return vector::OuterProductOp();
-
-      return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
-          mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
-          Value(), vector::CombiningKind::ADD);
+      if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
+        return failure();
+
+      auto broadcastedLhs =
+          transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
+      if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
+        return failure();
+
+      auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
+      if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
+        return failure();
+
+      return rewriter.create<vector::OuterProductOp>(
+          mulOp->getLoc(), resType, broadcastedLhs.getSource(),
+          broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
     };
-    Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
-    vector::OuterProductOp outerP = canonicalize(a, b);
+
+    Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
+    auto maybeOuterP = matchOuterProduct(lhs, rhs);
     // Handle commutativity, the transposed op is the outerproduct LHS.
-    outerP = outerP ? outerP : canonicalize(b, a);
-    return outerP ? success() : failure();
+    if (failed(maybeOuterP))
+      maybeOuterP = matchOuterProduct(rhs, lhs);
+    if (failed(maybeOuterP))
+      return failure();
+    rewriter.replaceOp(mulOp, maybeOuterP->getResult());
+    return success();
   }
 };
 
@@ -1958,8 +1965,9 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
 
 void mlir::vector::populateElementwiseToVectorOpsPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
-               ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
+  patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
+               FoldArithToVectorOuterProduct<arith::MulIOp>>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 783deb276f3cc..4b38db79bff3e 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -95,12 +95,12 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func.func @ewise_outerproduct
+// CHECK-LABEL: func.func @arith_to_outerproduct_scalable_i32
 //  CHECK-SAME:   %[[LHS:.*]]: vector<[4]xi32>,
 //  CHECK-SAME:   %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
 //       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
 //       CHECK:     return %[[RES]] : vector<[4]x[4]xi32>
-func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+func.func @arith_to_outerproduct_scalable_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
   %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
   %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
   %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
@@ -108,30 +108,17 @@ func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> v
   return %mul: vector<[4]x[4]xi32>
 }
 
-// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+// CHECK-LABEL: func.func @arith_to_outerproduct_trans_rhs_f32
 //  CHECK-SAME:   %[[LHS:.*]]: vector<16xf32>,
-//  CHECK-SAME:   %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
-//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
-//       CHECK:     return %[[RES]] : vector<16x16xf32>
-func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
-  %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
-  %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
-  %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
-  %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
-  return %mul: vector<16x16xf32>
-}
-
-// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
-//  CHECK-SAME:   %[[LHS:.*]]: vector<8xf32>,
-//  CHECK-SAME:   %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> {
-//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32>
-//       CHECK:     return %[[RES]] : vector<8x4xf32>
-func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> {
-  %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32>
-  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
-  %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32>
-  %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32>
-  return %mul: vector<8x4xf32>
+//  CHECK-SAME:   %[[RHS:.*]]: vector<8xf32>) -> vector<8x16xf32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<8xf32>, vector<16xf32>
+//       CHECK:     return %[[RES]] : vector<8x16xf32>
+func.func @arith_to_outerproduct_trans_rhs_f32(%lhs: vector<16xf32>, %rhs: vector<8xf32>) -> vector<8x16xf32> {
+  %rhsBcast = vector.broadcast %rhs : vector<8xf32> to vector<16x8xf32>
+  %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+  %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<8x16xf32>
+  %mul = arith.mulf %lhsBcast, %rhsT : vector<8x16xf32>
+  return %mul: vector<8x16xf32>
 }
 
 module attributes {transform.with_named_sequence} {

>From ec3cd832d96f535bca33028dc5dba516f3577bd6 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 21 Jun 2024 15:51:30 +0800
Subject: [PATCH 6/6] fixup : coding style improvements (nfc)

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 827c789df7f50..1d124261d8eff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1823,9 +1823,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
       return false;
     // Avoid broadcast like f32 or vector<f32> -> ResType
     auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
-    if (!srcType || srcType.getRank() == 2)
-      return false;
-    return true;
+    return srcType && srcType.getRank() != 2;
   }
 
   LogicalResult matchAndRewrite(MulOpType mulOp,
@@ -1841,8 +1839,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
     auto matchOuterProduct =
         [&](Value operandA,
             Value operandB) -> FailureOr<vector::OuterProductOp> {
-      vector::TransposeOp transposedLhs =
-          dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp());
+      auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
       if (!transposedLhs)
         return failure();
       // Fail unless this is a true 2-D matrix transpose.



More information about the Mlir-commits mailing list