[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)

Han-Chung Wang llvmlistbot at llvm.org
Thu Jun 20 09:57:17 PDT 2024


================
@@ -1795,6 +1795,87 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
+/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
+/// B)`.
+/// Example:
+///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
+///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
+///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///
+/// Becomes :
+///
+///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///
+/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+template <typename MulOpType>
+struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
+  using OpRewritePattern<MulOpType>::OpRewritePattern;
+  // Returns whether a vector.broadcast matches requirements for an outerproduct
+  // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
+  bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
+    // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
+    // shape_casts/broadcasts which does not belong in this pattern.
+    if (!broadcastOp.computeBroadcastedUnitDims().empty())
+      return false;
+    // Avoid broadcast like f32 or vector<f32> -> ResType
+    auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcType || srcType.getRank() == 2)
+      return false;
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(MulOpType mulOp,
+                                PatternRewriter &rewriter) const override {
+    auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+    if (!resType)
+      return failure();
+    if (resType.getRank() != 2)
+      return failure();
+    /// If operandA can be written as tr(broadcast(A)) and operandB as
+    /// broadcast(B) where broadcasts are 1D-to-2D, create and return
+    /// vector.outerproduct(A, B). Returns failure() otherwise.
+    auto matchOuterProduct =
+        [&](Value operandA,
+            Value operandB) -> FailureOr<vector::OuterProductOp> {
+      vector::TransposeOp transposedLhs =
+          dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp());
----------------
hanhanW wrote:

nit: use `auto` here because the cast already spells out the type; we can also use `operandA.getDefiningOp<vector::TransposeOp>()` instead of `dyn_cast_or_null` on the defining op.

https://www.llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

> Don’t “almost always” use auto, but do use auto with initializers like cast<Foo>(...) or other places where the type is already obvious from the context.

https://github.com/llvm/llvm-project/pull/93664


More information about the Mlir-commits mailing list