[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)

Han-Chung Wang llvmlistbot at llvm.org
Thu May 30 16:51:30 PDT 2024


================
@@ -1795,6 +1795,81 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
   unsigned maxNumElementsToExtract = 0;
 };
 
+/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
+/// into vector.outerproduct(A, B) such as :
+/// ```mlir
+///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+///  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
+///  vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
+///  vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+///```
+/// Becomes :
+///```mlir
+///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+///```
+/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
+/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
+/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
+/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
+
----------------
hanhanW wrote:

I think the comment is easier to read without the ``` fences. Also, please remove the blank line between the comment and the pattern.

```suggestion
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
///
///  %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
///  %lhsT = vector.transpose %lhsBcast, [1, 0]
///      : vector<4x4xi32> to vector<4x4xi32>
///  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
///  %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///
/// Becomes :
///
///  %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///
/// Edge cases where the broadcast ops are not 1D to 2D, as follows, are not
/// handled:
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
```
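
For anyone checking the equivalence the comment describes, here is a minimal standalone sketch in plain C++ (no MLIR dependency) that models the three ops on fixed 4-element vectors and compares the unfolded chain against the outer product. The helper names (`broadcast`, `transpose`, `mul`, `unfolded`, `outerProduct`, `checkFold`) are hypothetical and exist only for illustration; they are not part of the patch or of any MLIR API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t N = 4;
using Vec = std::array<std::int32_t, N>;
using Mat = std::array<Vec, N>;

// Models vector.broadcast : vector<4xi32> to vector<4x4xi32>.
// Every row of the result is a copy of v, so m[i][j] == v[j].
Mat broadcast(const Vec &v) {
  Mat m{};
  for (auto &row : m)
    row = v;
  return m;
}

// Models vector.transpose with permutation [1, 0].
Mat transpose(const Mat &m) {
  Mat t{};
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      t[i][j] = m[j][i];
  return t;
}

// Models the elementwise arith.muli on vector<4x4xi32>.
Mat mul(const Mat &a, const Mat &b) {
  Mat r{};
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      r[i][j] = a[i][j] * b[i][j];
  return r;
}

// The op chain before the fold: mul(transpose(broadcast(lhs)), broadcast(rhs)).
Mat unfolded(const Vec &lhs, const Vec &rhs) {
  return mul(transpose(broadcast(lhs)), broadcast(rhs));
}

// Models vector.outerproduct: res[i][j] = lhs[i] * rhs[j].
Mat outerProduct(const Vec &lhs, const Vec &rhs) {
  Mat r{};
  for (std::size_t i = 0; i < N; ++i)
    for (std::size_t j = 0; j < N; ++j)
      r[i][j] = lhs[i] * rhs[j];
  return r;
}

// Asserts that both forms agree for the given inputs.
void checkFold(const Vec &lhs, const Vec &rhs) {
  assert(unfolded(lhs, rhs) == outerProduct(lhs, rhs));
}
```

Calling `checkFold({1, 2, 3, 4}, {5, 6, 7, 8})` exercises the 1D to 2D case from the comment. The transposed broadcast contributes `lhs[i]` and the plain broadcast contributes `rhs[j]`, which is why only that shape of broadcast maps directly onto the outer product.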

https://github.com/llvm/llvm-project/pull/93664

