[Mlir-commits] [mlir] 5c18ae3 - [MLIR][Tensor] Canonicalize expand/collapse_shape of splat to splat

Rob Suderman llvmlistbot at llvm.org
Wed Jan 4 13:10:32 PST 2023


Author: liqinweng
Date: 2023-01-04T13:07:55-08:00
New Revision: 5c18ae3135d1ff4b9e554480da78bc93e35ef00a

URL: https://github.com/llvm/llvm-project/commit/5c18ae3135d1ff4b9e554480da78bc93e35ef00a
DIFF: https://github.com/llvm/llvm-project/commit/5c18ae3135d1ff4b9e554480da78bc93e35ef00a.diff

LOG: [MLIR][Tensor] Canonicalize expand/collapse_shape of splat to splat

Collapsing or expanding a splatted value produces another splat of the same
scalar, so such reshapes are replaced with a single `tensor.splat` of the result type.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D140552

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c87a003fab3bb..cd962456ba424 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1382,6 +1382,24 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
+/// Folds reshape(splat x : src_type) : res_type into splat x : res_type.
+template <typename TensorReshapeOp>
+class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
+public:
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
+    // The new splat gets no dynamic-size operands: require a static result.
+    if (!splatOp || !reshapeOp.getResultType().hasStaticShape())
+      return failure();
+    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
+        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
+    return success();
+  }
+};
+
 /// Reshape of a FromElements can be replaced with a FromElements of the
 /// result type
 template <typename TensorReshapeOp>
@@ -1523,6 +1541,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
+              FoldReshapeWithSplat<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
               FoldDimOfCollapseShape>(context);
 }
@@ -1533,6 +1552,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
            ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
            FoldReshapeWithConstant<CollapseShapeOp>,
+           FoldReshapeWithSplat<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
           context);
 }
@@ -1540,6 +1560,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
 }
+/// Mirror of ExpandShapeOp::fold with the two reshape template args swapped.
 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
 }

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 2b11a33681679..6267c269ab0b7 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1013,9 +1013,34 @@ func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
 //   CHECK-NOT:   tensor.expand_shape
 //       CHECK:   return %[[CST]]
+// -----
+func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
+  %splat = tensor.splat %arg : tensor<2x4xf32>
+  %0 = tensor.expand_shape %splat [[0], [1, 2]]
+      : tensor<2x4xf32> into tensor<2x2x2xf32>
+  return %0 : tensor<2x2x2xf32>
+}
+// CHECK-LABEL: @expand_shape_splat
+// CHECK-SAME:    %[[ARG0:.+]]: f32
+//       CHECK:   %[[SPLAT:.*]] = tensor.splat %[[ARG0]] : tensor<2x2x2xf32>
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   return %[[SPLAT]]
 
 // -----
 
+func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
+  %splat = tensor.splat %arg : tensor<2x2x2xf32>
+  %0 = tensor.collapse_shape %splat [[0], [1, 2]]
+      : tensor<2x2x2xf32> into tensor<2x4xf32>
+  return %0 : tensor<2x4xf32>
+}
+// CHECK-LABEL: @collapse_shape_splat
+// CHECK-SAME:    %[[ARG0:.+]]: f32
+//       CHECK:   %[[SPLAT:.*]] = tensor.splat %[[ARG0]] : tensor<2x4xf32>
+//   CHECK-NOT:   tensor.collapse_shape
+//       CHECK:   return %[[SPLAT]]
+
+// -----
 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
   %c0 = arith.constant dense<42> : tensor<2x8xi16>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]


        


More information about the Mlir-commits mailing list