[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)

Han-Chung Wang llvmlistbot at llvm.org
Thu May 30 16:51:31 PDT 2024


================
@@ -1795,6 +1795,81 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
+/// Pattern aiming to fold a series of ops mul(tr(broadcast(A)), broadcast(B))
+/// (mul being arith.mulf or arith.muli) into vector.outerproduct(A, B), such as:
+/// ```mlir
+///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+///  %lhsT = vector.transpose %lhsBcast, [1, 0]
+///      : vector<4x4xi32> to vector<4x4xi32>
+///  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
+///  %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///```
+/// becomes:
+///```mlir
+///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///```
+/// Edge cases where the broadcast is not from 1-D to 2-D, such as the
+/// following, are not handled:
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+
+template <typename MulOpType>
+struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+  using OpRewritePattern<MulOpType>::OpRewritePattern;
+  // Helper function returning the source of the input broadcast if it matches
+  // requirements for an outerproduct pattern.
+  Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
+    // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
+    // shape_casts/broadcasts which does not belong in this pattern.
+    if (!broadcastOp.computeBroadcastedUnitDims().empty())
+      return Value();
+    // Avoid broadcast like f32 or vector<f32> -> ResType
+    auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVT || srcVT.getRank() != 1)
+      return Value();
+    return broadcastOp.getSource();
+  }
+
+  LogicalResult matchAndRewrite(MulOpType mulOp,
+                                PatternRewriter &rewriter) const override {
+    auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
+    if (!VT)
+      return failure();
+    if (VT.getRank() != 2)
+      return failure();
+
+    auto canonicalize = [&](Value OperandA,
+                            Value OperandB) -> vector::OuterProductOp {
+      vector::TransposeOp transposedLhs =
+          dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+      if (!transposedLhs)
+        return vector::OuterProductOp();
+      // Fail unless this is a true 2-D matrix transpose.
+      ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+      if (permutation[0] != 1 || permutation[1] != 0)
+        return vector::OuterProductOp();
+
+      vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
+          transposedLhs.getVector().getDefiningOp());
+      if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs))
+        return vector::OuterProductOp();
+
+      vector::BroadcastOp broadcastedRhs =
+          dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
+      if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs))
+        return vector::OuterProductOp();
+
+      return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
+          mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
+          Value(), vector::CombiningKind::ADD);
----------------
hanhanW wrote:

I'd suggest just returning the created op and replacing the old op with the new op at the end. That is the more common style in MLIR patterns.

https://github.com/llvm/llvm-project/pull/93664


More information about the Mlir-commits mailing list