[Mlir-commits] [mlir] [mlir][vector] Add ElementwiseToOuterproduct (PR #93664)
Hugo Trachino
llvmlistbot at llvm.org
Thu May 30 02:32:44 PDT 2024
https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/93664
>From dab57eb7626b30cacacf8a63e4707457b6788f7e Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 20 May 2024 18:58:54 +0800
Subject: [PATCH 1/2] [mlir][vector] Add ElementwiseToOuterproduct
---
.../mlir/Dialect/Vector/IR/VectorOps.h | 4 +
.../Vector/TransformOps/VectorTransformOps.td | 11 +++
.../TransformOps/VectorTransformOps.cpp | 5 ++
.../Vector/Transforms/VectorTransforms.cpp | 75 +++++++++++++++++++
.../test/Dialect/Vector/transform-vector.mlir | 38 ++++++++++
5 files changed, 133 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa..ac55433fadb2f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+/// Collect patterns that rewrite elementwise ops on vectors into vector
+/// dialect ops (e.g. arith.mulf/muli of broadcasts -> vector.outerproduct).
+void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..e1da09fba73a7 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.elementwise_to_vector",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Collect patterns that rewrite elementwise ops on vectors into vector
+ dialect ops (e.g. arith.mulf/muli of broadcasts -> vector.outerproduct).
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.reduction_to_contract",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..6e13749a66415 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}
+void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateElementwiseToVectorOpsPatterns(patterns);
+}
+
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index f29eba90c3ceb..d7ccfc4986068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
unsigned maxNumElementsToExtract = 0;
};
+/// Pattern folding mul(transpose(broadcast(A)), broadcast(B)), where mul is
+/// arith.mulf or arith.muli, into vector.outerproduct(A, B). For example:
+/// ```mlir
+/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
+/// %lhsT = vector.transpose %lhsBcast, [1, 0]
+///     : vector<4x4xi32> to vector<4x4xi32>
+/// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
+/// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
+/// ```
+/// becomes:
+/// ```mlir
+/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
+/// ```
+/// The transposed operand may be either side of the multiplication. Broadcasts
+/// that are not from a 1-D vector to a 2-D vector are not handled, e.g.:
+///   vector<1x4xf32> -> vector<4x4xf32>, f32 -> vector<4x4xf32>,
+///   vector<1x1xf32> -> vector<4x4xf32>.
+template <typename MulOpType>
+struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
+ using OpRewritePattern<MulOpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOpType mulOp,
+ PatternRewriter &rewriter) const override {
+ // Scalar muls (e.g. on f32) also reach this pattern; dyn_cast lets them
+ // fail the match where cast<VectorType> would assert.
+ auto VT = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
+ if (!VT || VT.getRank() != 2)
+ return failure();
+
+ auto canonicalize = [&](Value OperandA,
+ Value OperandB) -> vector::OuterProductOp {
+ vector::TransposeOp transposedLhs =
+ dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
+ if (!transposedLhs)
+ return vector::OuterProductOp();
+ // Fail unless this is a true 2-D matrix transpose.
+ ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
+ if (permutation[0] != 1 || permutation[1] != 0)
+ return vector::OuterProductOp();
+
+ // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
+ // generating shape_casts/broadcasts which do not belong in this pattern.
+ vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
+ transposedLhs.getVector().getDefiningOp());
+ if (!broadcastedLhs ||
+ !broadcastedLhs.computeBroadcastedUnitDims().empty())
+ return vector::OuterProductOp();
+ // Avoid broadcast f32 or vector<f32> -> ResType
+ auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
+ if (!srcVT || srcVT.getRank() != 1)
+ return vector::OuterProductOp();
+
+ vector::BroadcastOp broadcastedRhs =
+ dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
+ if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+ return vector::OuterProductOp();
+
+ return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
+ mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
+ Value(), vector::CombiningKind::ADD);
+ };
+ Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
+ // Handle commutativity: the transposed operand becomes the outerproduct
+ // LHS. `||` short-circuits, so mulOp is rewritten at most once.
+ bool rewritten = canonicalize(lhs, rhs) || canonicalize(rhs, lhs);
+ return success(rewritten);
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
maxNumElementsToExtract, benefit);
}
+void mlir::vector::populateElementwiseToVectorOpsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
+ ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 75b29e22b4d2c..c170486f6ce27 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -92,3 +92,63 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @ewise_outerproduct
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
+// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
+func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
+ %lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+ %rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
+ %mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
+ return %mul: vector<[4]x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
+// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
+// CHECK: return %[[RES]] : vector<16x16xf32>
+func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
+ %rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
+ %rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+ %lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
+ %mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
+ return %mul: vector<16x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.elementwise_to_vector
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Scalar multiplications also reach the pattern; they must be left untouched
+// (and must not crash the rewrite).
+// CHECK-LABEL: func.func @ewise_scalar_mul_not_folded
+// CHECK-NOT: vector.outerproduct
+// CHECK: arith.mulf
+func.func @ewise_scalar_mul_not_folded(%lhs: f32, %rhs: f32) -> f32 {
+ %mul = arith.mulf %lhs, %rhs : f32
+ return %mul: f32
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.elementwise_to_vector
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From 0f454b045f27c27c57601454854c8add7e147fb3 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 30 May 2024 17:30:03 +0800
Subject: [PATCH 2/2] Support lhs/rhs broadcast sources of different sizes
---
.../Vector/Transforms/VectorTransforms.cpp | 23 +++++++++++--------
.../test/Dialect/Vector/transform-vector.mlir | 13 +++++++++++
2 files changed, 27 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d7ccfc4986068..a48101699c4f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1815,6 +1815,19 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
template <typename MulOpType>
struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
using OpRewritePattern<MulOpType>::OpRewritePattern;
+ // Returns the source of `broadcastOp` if it can feed an outerproduct (a 1-D
+ // vector broadcast without stretched unit dims), and a null Value otherwise.
+ static Value getValidBroadcastSource(vector::BroadcastOp broadcastOp) {
+ // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
+ // shape_casts/broadcasts which do not belong in this pattern.
+ if (!broadcastOp.computeBroadcastedUnitDims().empty())
+ return Value();
+ // Avoid broadcast like f32 or vector<f32> -> ResType
+ auto srcVT = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVT || srcVT.getRank() != 1)
+ return Value();
+ return broadcastOp.getSource();
+ }
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
@@ -1835,21 +1848,15 @@ struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
if (permutation[0] != 1 || permutation[1] != 0)
return vector::OuterProductOp();
- // Fail in case it is not a 1-to-2 dimension to broadcast to avoid
- // generating shape_casts/broadcasts which do not belong in this pattern.
- vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
- transposedLhs.getVector().getDefiningOp());
- if (!broadcastedLhs ||
- !broadcastedLhs.computeBroadcastedUnitDims().empty())
- return vector::OuterProductOp();
- // Avoid broadcast f32 or vector<f32> -> ResType
- auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
- if (!srcVT || srcVT.getRank() != 1)
+ // getDefiningOp<OpTy>() returns null for block arguments, where a plain
+ // dyn_cast on their null defining Operation* would assert.
+ auto broadcastedLhs =
+ transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedLhs || !getValidBroadcastSource(broadcastedLhs))
return vector::OuterProductOp();
- vector::BroadcastOp broadcastedRhs =
- dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
- if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
+ auto broadcastedRhs = OperandB.getDefiningOp<vector::BroadcastOp>();
+ if (!broadcastedRhs || !getValidBroadcastSource(broadcastedRhs))
return vector::OuterProductOp();
return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index c170486f6ce27..783deb276f3cc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -121,6 +121,40 @@ func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<
return %mul: vector<16x16xf32>
}
+// CHECK-LABEL: func.func @ewise_outerproduct_different_sizes
+// CHECK-SAME: %[[LHS:.*]]: vector<8xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<4xf32>) -> vector<8x4xf32> {
+// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<8xf32>, vector<4xf32>
+// CHECK: return %[[RES]] : vector<8x4xf32>
+func.func @ewise_outerproduct_different_sizes(%lhs: vector<8xf32>, %rhs: vector<4xf32>) -> vector<8x4xf32> {
+ %lhsBcast = vector.broadcast %lhs : vector<8xf32> to vector<4x8xf32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+ %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<8x4xf32>
+ %mul = arith.mulf %lhsT, %rhsBcast : vector<8x4xf32>
+ return %mul: vector<8x4xf32>
+}
+
+// Block-argument operands have no defining op: the pattern must not match.
+// CHECK-LABEL: func.func @ewise_outerproduct_transpose_of_block_arg
+// CHECK-NOT: vector.outerproduct
+// CHECK: arith.mulf
+func.func @ewise_outerproduct_transpose_of_block_arg(%lhs: vector<4x4xf32>, %rhs: vector<4xf32>) -> vector<4x4xf32> {
+ %lhsT = vector.transpose %lhs, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %rhsBcast = vector.broadcast %rhs : vector<4xf32> to vector<4x4xf32>
+ %mul = arith.mulf %lhsT, %rhsBcast : vector<4x4xf32>
+ return %mul: vector<4x4xf32>
+}
+
+// CHECK-LABEL: func.func @ewise_outerproduct_rhs_block_arg
+// CHECK-NOT: vector.outerproduct
+// CHECK: arith.mulf
+func.func @ewise_outerproduct_rhs_block_arg(%lhs: vector<4xf32>, %rhs: vector<4x4xf32>) -> vector<4x4xf32> {
+ %lhsBcast = vector.broadcast %lhs : vector<4xf32> to vector<4x4xf32>
+ %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %mul = arith.mulf %lhsT, %rhs : vector<4x4xf32>
+ return %mul: vector<4x4xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
More information about the Mlir-commits
mailing list