[Mlir-commits] [mlir] [mlir][Tensor] Add pattern to fold concats of empty. (PR #98994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 15 22:48:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

A `tensor.concat` whose operands are all defined by `tensor.empty` ops can be replaced by a single `tensor.empty` of the concatenated result shape. Add this pattern to `populateFoldTensorEmptyPatterns`.

---
Full diff: https://github.com/llvm/llvm-project/pull/98994.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (+35-2) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+38) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 43ad0acaf7420..60b0c3e759b6c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -136,6 +136,38 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
   }
 };
 
+// Fold concat operation where all the operands are empty.
+struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
+      return failure();
+    }
+    auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
+    if (!firstEmptyOp) {
+      return failure();
+    }
+    auto isDefinedByEmptyOp = [](Value v) -> bool {
+      return v.getDefiningOp<tensor::EmptyOp>();
+    };
+    if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "not all operands are defined by an empty op");
+    }
+    SmallVector<SmallVector<OpFoldResult>> resultShape;
+    if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "failed to get result shape");
+    }
+    rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        concatOp, resultShape[0], concatOp.getResultType().getElementType());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -144,6 +176,7 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
-  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
-      patterns.getContext(), /*benefit=*/1);
+  patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
+               FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
+                                            /*benefit=*/1);
 }
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e94f6ec7ec56e..5beb8c250aa10 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -164,3 +164,41 @@ func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
 //       CHECK:   tensor.empty{{.*}} : tensor<?x10x40xf32>
 //       CHECK:   tensor.extract_slice
 //       CHECK:   tensor.extract_slice
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.fold_tensor_empty
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @concats_of_empty(
+    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> tensor<5x?x?xf32>
+{
+  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+  %1 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+  %2 = tensor.concat dim(1) %0, %1 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  return %2 : tensor<5x?x?xf32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @concats_of_empty(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+//       CHECK:   %[[D2:.+]] = tensor.dim %[[EMPTY0]], %[[C2]]
+//   CHECK-DAG:   %[[D0_1:.+]] = tensor.dim %[[EMPTY0]], %[[C1]]
+//   CHECK-DAG:   %[[D1_1:.+]] = tensor.dim %[[EMPTY1]], %[[C1]]
+//   CHECK-DAG:   %[[SUM:.+]] = affine.apply #[[MAP]]()[%[[D0_1]], %[[D1_1]]]
+//       CHECK:   %[[NEW_EMPTY:.+]] = tensor.empty(%[[SUM]], %[[D2]])
+//       CHECK:   return %[[NEW_EMPTY]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/98994


More information about the Mlir-commits mailing list