[Mlir-commits] [mlir] a480d75 - [mlir][vector] Fold transpose(broadcast(<scalar>))
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 1 11:51:46 PDT 2022
Author: Lei Zhang
Date: 2022-04-01T14:51:36-04:00
New Revision: a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35
URL: https://github.com/llvm/llvm-project/commit/a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35
DIFF: https://github.com/llvm/llvm-project/commit/a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35.diff
LOG: [mlir][vector] Fold transpose(broadcast(<scalar>))
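When the broadcast source is a scalar (or a single-element vector), every element of the broadcast result holds the same value, so the transpose op can be elided.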
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D122903
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db8d40ae19daf..c7d60e5e3aab6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4212,11 +4212,33 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
+// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>), eliding the
+// transpose. A one-element vector source is treated as a scalar as well.
+struct FoldTransposedScalarBroadcast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only a scalar or one-element source makes the transpose a no-op.
+    auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
+    if (srcVectorType && srcVectorType.getNumElements() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+    return success();
+  }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<TransposeFolder>(context);
+ // Both patterns are OpRewritePattern<vector::TransposeOp> rewrites.
+ results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
}
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2b02dc143b4e3..13022c29cd4e8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,3 +1348,27 @@ func @shuffle_nofold2(%v0 : vector<[4]xi32>, %v1 : vector<[2]xi32>) -> vector<4x
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
return %shuffle : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast1
+// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast1(%src: vector<1xf32>) -> vector<1x8xf32> {
+  %splat = vector.broadcast %src : vector<1xf32> to vector<8x1xf32>
+  %transposed = vector.transpose %splat, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %transposed : vector<1x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast2
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
+// CHECK: return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast2(%scalar: f32) -> vector<1x8xf32> {
+  %splat = vector.broadcast %scalar : f32 to vector<8x1xf32>
+  %transposed = vector.transpose %splat, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %transposed : vector<1x8xf32>
+}
More information about the Mlir-commits
mailing list