[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (PR #75558)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 14 20:53:49 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Quinn Dawkins (qedawkins)
<details>
<summary>Changes</summary>
When the result type is statically sized along the concatenated dimension but the inputs are dynamically sized along it, reifyResultShapes must return that static size rather than the sum of the inputs' dynamic sizes. This fixes the tensor.concat implementation of the interface for such cases.
---
Full diff: https://github.com/llvm/llvm-project/pull/75558.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+17-10)
- (modified) mlir/test/Dialect/Tensor/decompose-concat.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a257e5f4d9dc22..b3e56d591f0db9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -637,17 +637,24 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
}
}
- // Take the sum of the input sizes along the concatenated dim.
- AffineExpr sum = builder.getAffineDimExpr(0);
- SmallVector<OpFoldResult> sizes = {
- builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
- for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
- sum = sum + builder.getAffineDimExpr(idx + 1);
- sizes.push_back(
- builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+ if (getType().isDynamicDim(dim)) {
+ // Take the sum of the input sizes along the concatenated dim.
+ AffineExpr sum = builder.getAffineDimExpr(0);
+ SmallVector<OpFoldResult> sizes = {
+ builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+ for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+ sum = sum + builder.getAffineDimExpr(idx + 1);
+ sizes.push_back(
+ builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+ }
+ reifiedReturnShapes[0][dim] =
+ affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+ } else {
+ // If the result shape is static along the concatenated dim, use the static
+ // shape.
+ reifiedReturnShapes[0][dim] =
+ builder.getIndexAttr(getType().getDimSize(dim));
}
- reifiedReturnShapes[0][dim] =
- affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
// ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
// returns a Value for dynamic dimensions.
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index b9497a61015af0..159347b4f7aa28 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -51,6 +51,27 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
+
+func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
+ %arg1: tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
+ %0 = tensor.concat dim(2) %arg0, %arg1
+ : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
+ return %0 : tensor<1x?x128xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+// CHECK: tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
+// CHECK: %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+// CHECK: tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
+// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
+// CHECK: %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+// CHECK: %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
+// CHECK-SAME: tensor<1x?x?xf32> into tensor<1x?x128xf32>
+// CHECK: return %[[CONCAT]] : tensor<1x?x128xf32>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
transform.apply_patterns to %func_op {
``````````
</details>
https://github.com/llvm/llvm-project/pull/75558
More information about the Mlir-commits mailing list