[Mlir-commits] [mlir] a480d75 - [mlir][vector] Fold transpose(broadcast(<scalar>))

Lei Zhang llvmlistbot at llvm.org
Fri Apr 1 11:51:46 PDT 2022


Author: Lei Zhang
Date: 2022-04-01T14:51:36-04:00
New Revision: a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35

URL: https://github.com/llvm/llvm-project/commit/a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35
DIFF: https://github.com/llvm/llvm-project/commit/a480d75fe48d6c4e0ab4ae5fbbf719c57f5ced35.diff

LOG: [mlir][vector] Fold transpose(broadcast(<scalar>))

For such cases, the transpose op can be elided.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D122903

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db8d40ae19daf..c7d60e5e3aab6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4212,11 +4212,33 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+// Folds transpose(broadcast(<scalar or 1-element vector>)) into broadcast(<src>).
+struct FoldTransposedScalarBroadcast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVectorType || srcVectorType.getNumElements() == 1) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<TransposeFolder>(context);
+  results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2b02dc143b4e3..13022c29cd4e8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,3 +1348,27 @@ func @shuffle_nofold2(%v0 : vector<[4]xi32>, %v1 : vector<[2]xi32>) -> vector<4x
   %shuffle = vector.shuffle %v0, %v1 [0, 1, 2, 3] : vector<[4]xi32>, vector<[2]xi32>
   return %shuffle : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast1
+//  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x8xf32>
+//       CHECK:   return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast1(%value: vector<1xf32>) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : vector<1xf32> to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_scalar_broadcast2
+//  CHECK-SAME: (%[[ARG:.+]]: f32)
+//       CHECK:   %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x8xf32>
+//       CHECK:   return %[[V]] : vector<1x8xf32>
+func @transpose_scalar_broadcast2(%value: f32) -> vector<1x8xf32> {
+  %bcast = vector.broadcast %value : f32 to vector<8x1xf32>
+  %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
+  return %t : vector<1x8xf32>
+}


        


More information about the Mlir-commits mailing list