[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (PR #75558)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 20:53:49 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir-tensor

Author: Quinn Dawkins (qedawkins)

Changes

When the concatenated dimension is statically sized in the result type but the inputs are dynamically sized along it, reifyResultShapes must return the static size for that dimension. This fixes the implementation of the interface for tensor.concat in such cases.
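
As a concrete illustration (adapted from the test added in this PR), the result type below is static along the concatenated dimension even though both inputs are dynamic there, so reifyResultShapes should report the constant 128 rather than summing `tensor.dim` values:

```mlir
// Both inputs are dynamic along dim 2, but the result type pins that
// dimension to 128, so the reified shape for dim 2 is the static size.
func.func @static_concat_dim(%arg0 : tensor<1x?x?xf32>,
                             %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
  %0 = tensor.concat dim(2) %arg0, %arg1
         : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
  return %0 : tensor<1x?x128xf32>
}
```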

---
Full diff: https://github.com/llvm/llvm-project/pull/75558.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+17-10) 
- (modified) mlir/test/Dialect/Tensor/decompose-concat.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a257e5f4d9dc22..b3e56d591f0db9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -637,17 +637,24 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
     }
   }
 
-  // Take the sum of the input sizes along the concatenated dim.
-  AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes = {
-      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
-  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
-    sum = sum + builder.getAffineDimExpr(idx + 1);
-    sizes.push_back(
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  if (getType().isDynamicDim(dim)) {
+    // Take the sum of the input sizes along the concatenated dim.
+    AffineExpr sum = builder.getAffineDimExpr(0);
+    SmallVector<OpFoldResult> sizes = {
+        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+      sum = sum + builder.getAffineDimExpr(idx + 1);
+      sizes.push_back(
+          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+    }
+    reifiedReturnShapes[0][dim] =
+        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+  } else {
+    // If the result shape is static along the concatenated dim, use the static
+    // shape.
+    reifiedReturnShapes[0][dim] =
+        builder.getIndexAttr(getType().getDimSize(dim));
   }
-  reifiedReturnShapes[0][dim] =
-      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
 
   // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
   // returns a Value for dynamic dimensions.
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index b9497a61015af0..159347b4f7aa28 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -51,6 +51,27 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
 
+
+func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
+                               %arg1: tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
+  %0 = tensor.concat dim(2) %arg0, %arg1
+             : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
+  return %0 : tensor<1x?x128xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
+//       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     return %[[CONCAT]] : tensor<1x?x128xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {

``````````



https://github.com/llvm/llvm-project/pull/75558

