[Mlir-commits] [mlir] d81a3c5 - [mlir] Fold tensor.reshape operations into tensor.from_elements.

Rob Suderman llvmlistbot at llvm.org
Tue Jan 25 15:57:22 PST 2022


Author: Rob Suderman
Date: 2022-01-25T15:54:57-08:00
New Revision: d81a3c51e7f76c4b3f7ed687f82a019168aad2da

URL: https://github.com/llvm/llvm-project/commit/d81a3c51e7f76c4b3f7ed687f82a019168aad2da
DIFF: https://github.com/llvm/llvm-project/commit/d81a3c51e7f76c4b3f7ed687f82a019168aad2da.diff

LOG: [mlir] Fold tensor.reshape operations into tensor.from_elements.

There is little benefit to reshaping the result of a tensor.from_elements
instead of creating it with the reshaped type directly. Such reshapes are now
folded by propagating the shape change into the result type of the
tensor.from_elements.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D118201

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 42f57a9cf99bd..5ae13a613c427 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -798,20 +798,45 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
   }
 };
 
+/// Folds a reshape (expand_shape/collapse_shape) of a tensor.from_elements
+/// into a tensor.from_elements of the result type with the same elements.
+template <typename TensorReshapeOp>
+struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElements =
+        reshapeOp.src().template getDefiningOp<FromElementsOp>();
+    if (!fromElements)
+      return failure();
+
+    auto shapedTy = reshapeOp.getType().template cast<ShapedType>();
+    // Only fold when the reshaped result type is fully static.
+    if (!shapedTy.hasStaticShape())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
+                                                fromElements.elements());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<CollapseReshapeOps<ExpandShapeOp>,
               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>,
-              FoldReshapeWithConstant<ExpandShapeOp>>(context);
+              FoldReshapeWithConstant<ExpandShapeOp>,
+              FoldReshapeWithFromElements<ExpandShapeOp>>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<CollapseReshapeOps<CollapseShapeOp>,
               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
-              FoldReshapeWithConstant<CollapseShapeOp>>(context);
+              FoldReshapeWithConstant<CollapseShapeOp>,
+              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
 }
 
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 10d39132a1126..e0ea5d777acb8 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1178,3 +1178,25 @@ func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tenso
 
   return %0 : tensor<2x3x4xf32>
 }
+
+// -----
+// collapse_shape of a tensor.from_elements folds into a rank-0 from_elements.
+// CHECK-LABEL: func @fold_collapse_shape_from_elements
+func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
+  // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
+  // CHECK: return %[[FROM]] : tensor<i32>
+  %0 = tensor.from_elements %arg0 : tensor<1xi32>
+  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  return %1 : tensor<i32>
+}
+
+// -----
+// expand_shape of a rank-0 tensor.from_elements folds into a rank-1 one.
+// CHECK-LABEL: func @fold_expand_shape_from_elements
+func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
+  // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
+  // CHECK: return %[[FROM]] : tensor<1xi32>
+  %0 = tensor.from_elements %arg0 : tensor<i32>
+  %1 = tensor.expand_shape %0 [] : tensor<i32> into tensor<1xi32>
+  return %1 : tensor<1xi32>
+}


        


More information about the Mlir-commits mailing list