[Mlir-commits] [mlir] fcd54b3 - [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (#75558)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 15 05:44:02 PST 2023


Author: Quinn Dawkins
Date: 2023-12-15T08:43:58-05:00
New Revision: fcd54b368e6713acd236dc47401b5292755900d7

URL: https://github.com/llvm/llvm-project/commit/fcd54b368e6713acd236dc47401b5292755900d7
DIFF: https://github.com/llvm/llvm-project/commit/fcd54b368e6713acd236dc47401b5292755900d7.diff

LOG: [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (#75558)

When the result type is statically sized along the concatenated dim but
the inputs are dynamically sized along it, reifyResultShapes must return
the static size for that dim. This fixes the tensor.concat implementation
of the interface for such cases.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/decompose-concat.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a257e5f4d9dc22..9ef4ae84536841 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -629,33 +629,33 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
     if (!getType().isDynamicDim(i)) {
       reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
     } else if (!inferredResultType.isDynamicDim(i)) {
-      reifiedReturnShapes[0][i] =
-          builder.getIndexAttr(inferredResultType.getDimSize(i));
+      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
+          builder, getLoc(),
+          builder.getIndexAttr(inferredResultType.getDimSize(i)));
     } else {
       reifiedReturnShapes[0][i] =
           builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
     }
   }
 
-  // Take the sum of the input sizes along the concatenated dim.
-  AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes = {
-      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
-  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
-    sum = sum + builder.getAffineDimExpr(idx + 1);
-    sizes.push_back(
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
-  }
-  reifiedReturnShapes[0][dim] =
-      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
-
-  // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
-  // returns a Value for dynamic dimensions.
-  for (int64_t i = 0; i < rank; ++i) {
-    if (getType().isDynamicDim(i)) {
-      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
-          builder, getLoc(), reifiedReturnShapes[0][i]);
+  if (getType().isDynamicDim(dim)) {
+    // Take the sum of the input sizes along the concatenated dim.
+    AffineExpr sum = builder.getAffineDimExpr(0);
+    SmallVector<OpFoldResult> sizes = {
+        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+      sum = sum + builder.getAffineDimExpr(idx + 1);
+      sizes.push_back(
+          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
     }
+    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
+        builder, getLoc(),
+        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
+  } else {
+    // If the result shape is static along the concatenated dim, use the static
+    // shape.
+    reifiedReturnShapes[0][dim] =
+        builder.getIndexAttr(getType().getDimSize(dim));
   }
   return success();
 }

diff  --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index b9497a61015af0..159347b4f7aa28 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -51,6 +51,27 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
 
+
+func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
+                               %arg1: tensor<1x?x?xf32>) -> tensor<1x?x128xf32> {
+  %0 = tensor.concat dim(2) %arg0, %arg1
+             : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
+  return %0 : tensor<1x?x128xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:     %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
+//       CHECK:     %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:       tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:     %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
+//  CHECK-SAME:        tensor<1x?x?xf32> into tensor<1x?x128xf32>
+//       CHECK:     return %[[CONCAT]] : tensor<1x?x128xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {


        


More information about the Mlir-commits mailing list