[Mlir-commits] [mlir] 5c18ae3 - [MLIR][Tensor] Canonicalize expand/collapse_shape of splat to splat
Rob Suderman
llvmlistbot at llvm.org
Wed Jan 4 13:10:32 PST 2023
Author: liqinweng
Date: 2023-01-04T13:07:55-08:00
New Revision: 5c18ae3135d1ff4b9e554480da78bc93e35ef00a
URL: https://github.com/llvm/llvm-project/commit/5c18ae3135d1ff4b9e554480da78bc93e35ef00a
DIFF: https://github.com/llvm/llvm-project/commit/5c18ae3135d1ff4b9e554480da78bc93e35ef00a.diff
LOG: [MLIR][Tensor] Canonicalize expand/collapse_shape of splat to splat
Collapsing or expanding a splatted value yields the same splatted value in the new shape, so such
reshapes can be replaced with a single `tensor.splat` of the result type.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D140552
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c87a003fab3bb..cd962456ba424 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1382,6 +1382,24 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
}
};
+/// Folds TensorReshapeOp(splat x) : res_type into splat x : res_type.
+template <typename TensorReshapeOp>
+class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
+public:
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
+ // The replacement splat gets no dynamic sizes: require a static result.
+ if (!splatOp || !reshapeOp.getResultType().hasStaticShape())
+ return failure();
+ rewriter.replaceOpWithNewOp<tensor::SplatOp>(
+ reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
+ return success();
+ }
+};
+
/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type
template <typename TensorReshapeOp>
@@ -1523,6 +1541,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
FoldReshapeWithConstant<ExpandShapeOp>,
+ FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
FoldDimOfCollapseShape>(context);
}
@@ -1533,6 +1552,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp>,
FoldReshapeWithConstant<CollapseShapeOp>,
+ FoldReshapeWithSplat<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
context);
}
@@ -1540,6 +1560,7 @@ void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
}
+
OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 2b11a33681679..6267c269ab0b7 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1013,9 +1013,34 @@ func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
// CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
// CHECK-NOT: tensor.expand_shape
// CHECK: return %[[CST]]
+// -----
+func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
+ %splat = tensor.splat %arg : tensor<2x4xf32>
+ %0 = tensor.expand_shape %splat [[0], [1, 2]]
+ : tensor<2x4xf32> into tensor<2x2x2xf32>
+ return %0 : tensor<2x2x2xf32>
+}
+// CHECK-LABEL: @expand_shape_splat
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[SPLAT:.*]] = tensor.splat %[[ARG0]] : tensor<2x2x2xf32>
+// CHECK-NOT: tensor.expand_shape
+// CHECK: return %[[SPLAT]]
// -----
+func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
+ %splat = tensor.splat %arg : tensor<2x2x2xf32>
+ %0 = tensor.collapse_shape %splat [[0], [1, 2]]
+ : tensor<2x2x2xf32> into tensor<2x4xf32>
+ return %0 : tensor<2x4xf32>
+}
+// CHECK-LABEL: @collapse_shape_splat
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[SPLAT:.*]] = tensor.splat %[[ARG0]] : tensor<2x4xf32>
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: return %[[SPLAT]]
+
+// -----
func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
%c0 = arith.constant dense<42> : tensor<2x8xi16>
%0 = tensor.expand_shape %c0 [[0], [1, 2]]
More information about the Mlir-commits mailing list