[Mlir-commits] [mlir] d81a3c5 - [mlir] Fold tensor.reshape operations into tensor.from_elements.
Rob Suderman
llvmlistbot at llvm.org
Tue Jan 25 15:57:22 PST 2022
Author: Rob Suderman
Date: 2022-01-25T15:54:57-08:00
New Revision: d81a3c51e7f76c4b3f7ed687f82a019168aad2da
URL: https://github.com/llvm/llvm-project/commit/d81a3c51e7f76c4b3f7ed687f82a019168aad2da
DIFF: https://github.com/llvm/llvm-project/commit/d81a3c51e7f76c4b3f7ed687f82a019168aad2da.diff
LOG: [mlir] Fold tensor.reshape operations into tensor.from_elements.
There is little benefit to reshaping the result of a tensor.from_elements
versus rebuilding it directly. Updated to propagate shape manipulations into
the output type of tensor.from_elements.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D118201
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 42f57a9cf99bd..5ae13a613c427 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -798,20 +798,45 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
}
};
+/// Folds a reshape (tensor.expand_shape / tensor.collapse_shape) whose source
+/// is produced by tensor.from_elements into a single tensor.from_elements that
+/// builds the reshaped result type directly from the same scalar operands.
+template <typename TensorReshapeOp>
+struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Only applies when the reshaped value comes straight from from_elements.
+    auto producer = reshapeOp.src().template getDefiningOp<FromElementsOp>();
+    if (!producer)
+      return failure();
+
+    // tensor.from_elements can only materialize statically shaped tensors, so
+    // bail out if the reshape result has any dynamic dimension.
+    if (!reshapeOp.getType().template cast<ShapedType>().hasStaticShape())
+      return failure();
+
+    // The rewrite relies on expand/collapse_shape only regrouping dimensions:
+    // the elements keep their linearized order, so the operands are reused.
+    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
+                                                producer.elements());
+    return success();
+  }
+};
+
} // namespace
+// Canonicalization patterns for tensor.expand_shape. FoldReshapeWithFromElements
+// is appended to the existing list so that expand_shape(from_elements) becomes
+// a single from_elements of the expanded type; the other registered patterns
+// are defined outside this hunk.
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<ExpandShapeOp>,
CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
- FoldReshapeWithConstant<ExpandShapeOp>>(context);
+ FoldReshapeWithConstant<ExpandShapeOp>,
+ FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
+// Canonicalization patterns for tensor.collapse_shape. FoldReshapeWithFromElements
+// is appended to the existing list so that collapse_shape(from_elements) becomes
+// a single from_elements of the collapsed type; the other registered patterns
+// are defined outside this hunk.
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseReshapeOps<CollapseShapeOp>,
CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
- FoldReshapeWithConstant<CollapseShapeOp>>(context);
+ FoldReshapeWithConstant<CollapseShapeOp>,
+ FoldReshapeWithFromElements<CollapseShapeOp>>(context);
}
OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 10d39132a1126..e0ea5d777acb8 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1178,3 +1178,25 @@ func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tenso
return %0 : tensor<2x3x4xf32>
}
+
+// -----
+
+// Collapsing a single-element rank-1 from_elements to rank 0 should leave only
+// a rank-0 from_elements feeding the return.
+// CHECK-LABEL: func @fold_collapse_shape_from_elements
+// CHECK-SAME:    (%[[ARG0:.+]]: i32)
+// CHECK:         %[[FROM:.+]] = tensor.from_elements %[[ARG0]] : tensor<i32>
+// CHECK:         return %[[FROM]] : tensor<i32>
+func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
+  %elems = tensor.from_elements %arg0 : tensor<1xi32>
+  %collapsed = tensor.collapse_shape %elems [] : tensor<1xi32> into tensor<i32>
+  return %collapsed : tensor<i32>
+}
+
+// -----
+
+// Expanding a rank-0 from_elements to a single-element rank-1 tensor should
+// leave only a rank-1 from_elements feeding the return.
+// CHECK-LABEL: func @fold_expand_shape_from_elements
+// CHECK-SAME:    (%[[ARG0:.+]]: i32)
+// CHECK:         %[[FROM:.+]] = tensor.from_elements %[[ARG0]] : tensor<1xi32>
+// CHECK:         return %[[FROM]] : tensor<1xi32>
+func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
+  %elems = tensor.from_elements %arg0 : tensor<i32>
+  %expanded = tensor.expand_shape %elems [] : tensor<i32> into tensor<1xi32>
+  return %expanded : tensor<1xi32>
+}
More information about the Mlir-commits
mailing list