[Mlir-commits] [mlir] [mlir][tensor] Fix ReifyResultShapes implementation for tensor.concat (PR #74157)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 1 15:20:47 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

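Without folding the result of the initial `tensor.dim`, the `ReifyResultShapes` implementation for `tensor.concat` was incorrect, because it returned a dynamic size for a result dimension whose shape is static. This change creates that op with `createOrFold` so static sizes fold to constants, and it also queries the first input's size along the concatenated dimension (`dim`) instead of the hard-coded dimension 0.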

---
Full diff: https://github.com/llvm/llvm-project/pull/74157.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-1) 
- (modified) mlir/test/Dialect/Tensor/decompose-concat.mlir (+26-22) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 02146e8257b38e3..8970ea1c73b403f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -605,7 +605,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
   // Take the sum of the input sizes along the concatenated dim.
   AffineExpr sum = builder.getAffineDimExpr(0);
   SmallVector<OpFoldResult> sizes = {
-      builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
+      builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
   for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
     sum = sum + builder.getAffineDimExpr(idx + 1);
     sizes.push_back(
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index 5712c77a743d71b..b9497a61015af0a 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -1,19 +1,10 @@
 // RUN: mlir-opt -split-input-file -transform-interpreter -cse  %s | FileCheck %s
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.tensor.decompose_concat
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
-
 func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-LABEL: func @decompose_dynamic_concat(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -22,24 +13,13 @@ func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?x
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 //       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
+//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
 //       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
 //       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
 //       CHECK:     %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
 //       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 //       CHECK:     return %[[CONCAT]] : tensor<?x?xf32>
 
-// -----
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.tensor.decompose_concat
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
-
 func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
                             %arg1 : tensor<2xf32>,
                             %arg2 : tensor<3xf32>,
@@ -55,3 +35,27 @@ func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
 //       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
 //       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<10xf32>
+
+func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
+                               %arg1: tensor<1x?x64xf32>) -> tensor<1x?x128xf32> {
+  %0 = tensor.concat dim(2) %arg0, %arg1
+             : (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
+  return %0 : tensor<1x?x128xf32>
+}
+// CHECK-LABEL: func @decompose_static_concat_dim
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:    tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
+//       CHECK:     %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
+//       CHECK:    return %[[CONCAT]] : tensor<1x?x128xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/74157


More information about the Mlir-commits mailing list