[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)

Benjamin Maxwell llvmlistbot at llvm.org
Wed May 29 09:38:35 PDT 2024


================
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @ewise_outerproduct
+//  CHECK-SAME:   %[[LHS:.*]]: vector<[4]xi32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+//       CHECK:     return %[[RES]] : vector<[4]x[4]xi32>
+func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+  %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+  %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+  return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+//  CHECK-SAME:   %[[LHS:.*]]: vector<16xf32>,
+//  CHECK-SAME:   %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
+//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
+//       CHECK:     return %[[RES]] : vector<16x16xf32>
+func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
+  %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
+  %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
+  %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
+  return %mul: vector<16x16xf32>
+}
+
----------------
MacDue wrote:

Can this pattern handle the case where the lhs and rhs have different sizes? e.g.:
```
%lhsBcast = vector.broadcast %lhs : vector<8xi32> to vector<4x8xi32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xi32> to vector<8x4xi32>
%rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<8x4xi32>
%mul = arith.muli %lhsT, %rhsBcast : vector<8x4xi32>
```

->
```
vector.outerproduct %lhs, %rhs : vector<8xi32>, vector<4xi32>
```

If so, it'd be nice to have some tests for that.
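If the pattern does support this, a test for the mixed-size case could look like the following sketch. It reuses the IR from the example above and checks for the expected `vector.outerproduct` producing a `vector<8x4xi32>`. The function name `@ewise_outerproduct_different_sizes` is hypothetical, and the sketch assumes it lives in the same test file as the two new tests in this diff, driven by the same RUN line.

```mlir
// -----

// Hypothetical test name; assumes the same RUN line as the tests in this diff.
// lhs (8 elements) indexes rows, rhs (4 elements) indexes columns, so the
// result should be a single 8x4 outer product with the operands in order.
// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
//  CHECK-SAME:   %[[LHS:.*]]: vector<8xi32>,
//  CHECK-SAME:   %[[RHS:.*]]: vector<4xi32>) -> vector<8x4xi32> {
//       CHECK:     %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xi32>, vector<4xi32>
//       CHECK:     return %[[RES]] : vector<8x4xi32>
func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xi32>, %rhs: vector<4xi32>) -> vector<8x4xi32> {
  // Broadcast lhs along a new leading dim of size 4, then transpose so that
  // element (i, j) holds lhs[i].
  %lhsBcast = vector.broadcast %lhs : vector<8xi32> to vector<4x8xi32>
  %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xi32> to vector<8x4xi32>
  // Broadcast rhs along a new leading dim of size 8: element (i, j) holds rhs[j].
  %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<8x4xi32>
  %mul = arith.muli %lhsT, %rhsBcast : vector<8x4xi32>
  return %mul : vector<8x4xi32>
}
```

A mirrored case with the transpose on the `rhs` side (as in `@ewise_outerproduct_transposed_rhs`) would cover the operand-swapping path with mismatched sizes as well.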

https://github.com/llvm/llvm-project/pull/93664
