[Mlir-commits] [mlir] c045955 - [mlir][tensor] Fold `tensor.reshape` for dynamic reshape (#88961)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 19 10:36:13 PDT 2024
Author: Rob Suderman
Date: 2024-04-19T10:36:09-07:00
New Revision: c045955501ed28fee7c40d8822a1aacc2022786e
URL: https://github.com/llvm/llvm-project/commit/c045955501ed28fee7c40d8822a1aacc2022786e
DIFF: https://github.com/llvm/llvm-project/commit/c045955501ed28fee7c40d8822a1aacc2022786e.diff
LOG: [mlir][tensor] Fold `tensor.reshape` for dynamic reshape (#88961)
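If the shape operand of `tensor.reshape` is built with `tensor.from_elements`
from `d0, d1, d2, ...` (the source's own dimension sizes, in order) and the
result type equals the source type, we know that the reshape is a no-op.
Checking for this case lets us fold away the computation.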
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 80bc04d62bbe84..3ff41ab22fbc42 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1580,6 +1580,48 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
+
+ auto source = getSource();
+ auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
+ auto resultTy = dyn_cast<RankedTensorType>(getType());
+
+ if (!sourceTy || !resultTy || sourceTy != resultTy)
+ return {};
+
+ // The reshape is a no-op when the shape operand is assembled from the
+ // source's own sizes: element `id` is either a constant equal to the static
+ // size of dimension `id`, or `tensor.dim %source, id`.
+ if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
+ auto elements = fromElements.getElements();
+ bool dynamicNoop =
+ sourceTy.getRank() == static_cast<int64_t>(elements.size());
+ for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
+ auto element = elements[id];
+
+ if (auto cst = getConstantIntValue(element)) {
+ dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
+ continue;
+ }
+
+ if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
+ dynamicNoop &= dimOp.getSource() == source;
+
+ auto cst = getConstantIntValue(dimOp.getIndex());
+ dynamicNoop &=
+ cst.has_value() && cst.value() == static_cast<int64_t>(id);
+ continue;
+ }
+
+ // Any other value (e.g. a block argument or arithmetic result) is not
+ // known to equal the source's size along `id`; the reshape could then
+ // redistribute the extents, so conservatively refuse to fold.
+ dynamicNoop = false;
+ }
+
+ if (dynamicNoop)
+ return source;
+ }
+
return {};
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..751c57eacd7ae5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2403,6 +2403,62 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
// -----
+// CHECK-LABEL: @reshape_fold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+ // CHECK: return %[[ARG0]]
+ return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
+ // CHECK: tensor.reshape
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+ return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d_ins
+func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
+ %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
+ // CHECK: tensor.reshape
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+ return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_fold_3d_cst
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
+func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %d0 = arith.constant 5 : index
+ %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
+ %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
+ %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
+ // CHECK: return %[[ARG0]]
+ return %reshape : tensor<5x?x?xi32>
+}
+
+// -----
+
// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
// CHECK-LABEL: func @dim_out_of_bounds(
// CHECK: %[[IDX:.*]] = index.constant 28
More information about the Mlir-commits
mailing list