[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (PR #75558)

Quinn Dawkins llvmlistbot at llvm.org
Thu Dec 14 20:53:16 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/75558

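When the result type of tensor.concat is statically sized along the concatenated dim but the inputs are dynamically sized along it, reifyResultShapes must return the static size for that dim. Previously the implementation always returned the sum of the input sizes along the concatenated dim; this PR fixes the interface implementation for tensor.concat so that the static result size is returned in such cases.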

From 7b419d24f9995c093d58a59d27acedda4d20b044 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 14 Dec 2023 23:33:58 -0500
Subject: [PATCH] [mlir][tensor] Fix tensor.concat reifyResultShapes for static
 concatenated result dims

When the concatenated dim is statically sized but the inputs are
dynamically sized, reifyResultShapes must return the static shape.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 27 ++++++++++++-------
 .../test/Dialect/Tensor/decompose-concat.mlir | 21 +++++++++++++++
 2 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a257e5f4d9dc22..b3e56d591f0db9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -637,17 +637,24 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
     }
   }
 
-  // Take the sum of the input sizes along the concatenated dim.
-  AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes = {
-      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
-  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
-    sum = sum + builder.getAffineDimExpr(idx + 1);
-    sizes.push_back(
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  if (getType().isDynamicDim(dim)) {
+    // Take the sum of the input sizes along the concatenated dim.
+    AffineExpr sum = builder.getAffineDimExpr(0);
+    SmallVector<OpFoldResult> sizes = {
+        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+      sum = sum + builder.getAffineDimExpr(idx + 1);
+      sizes.push_back(
+          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+    }
+    reifiedReturnShapes[0][dim] =
+        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+  } else {
+    // If the result shape is static along the concatenated dim, use the static
+    // shape.
+    reifiedReturnShapes[0][dim] =
+        builder.getIndexAttr(getType().getDimSize(dim));
   }
-  reifiedReturnShapes[0][dim] =
-      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
 
   // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
   // returns a Value for dynamic dimensions.
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index b9497a61015af0..159347b4f7aa28 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -51,6 +51,27 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
 
+
+func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
+                               %arg1: tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
+  %0 = tensor.concat dim(2) %arg0, %arg1
+             : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
+  return %0 : tensor<1x?x128xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
+//       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     return %[[CONCAT]] : tensor<1x?x128xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {



More information about the Mlir-commits mailing list