[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.concat reifyResultShapes for static result dims (PR #75558)

Matthias Springer llvmlistbot at llvm.org
Thu Dec 14 23:25:00 PST 2023


================
@@ -637,17 +637,24 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
     }
   }
 
-  // Take the sum of the input sizes along the concatenated dim.
-  AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes = {
-      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
-  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
-    sum = sum + builder.getAffineDimExpr(idx + 1);
-    sizes.push_back(
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  if (getType().isDynamicDim(dim)) {
+    // Take the sum of the input sizes along the concatenated dim.
+    AffineExpr sum = builder.getAffineDimExpr(0);
+    SmallVector<OpFoldResult> sizes = {
+        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
+    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+      sum = sum + builder.getAffineDimExpr(idx + 1);
+      sizes.push_back(
+          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+    }
+    reifiedReturnShapes[0][dim] =
+        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
----------------
matthias-springer wrote:

This must be wrapped in `getValueOrCreateConstantIndexOp` (see https://github.com/llvm/llvm-project/blob/main/mlir/lib/Interfaces/InferTypeOpInterface.cpp#L54).

https://github.com/llvm/llvm-project/pull/75558

