[Mlir-commits] [mlir] [mlir][Tensor] Use output_shape for ExpandShapeOp type inference (PR #118202)

llvmlistbot at llvm.org
Sun Dec 1 02:00:25 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes:

Since `tensor.expand_shape` already carries its full output shape in the operation (as a mix of static sizes and dynamic SSA operands), there is no need to recompute it arithmetically from the source shape. This PR makes the reified shape inference use that output shape directly.

---
Full diff: https://github.com/llvm/llvm-project/pull/118202.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+23-96) 
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebb88bf695d4c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyCollapseShapeOp, CollapseShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+    auto loc = op->getLoc();
+    auto collapseShape = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+        b, loc, collapseShape.getSrc(),
+        collapseShape.getResultType().getShape(),
+        collapseShape.getReassociationMaps()));
+    return success();
   }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
+};
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
-    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
-        b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps()));
+    auto expandShape = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> outputShape = getMixedValues(
+        expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+    reifiedReturnShapes.push_back(outputShape);
     return success();
   }
 };
@@ -202,10 +131,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..0595ac2492c97a 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @dim_reshape_expansion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..d3889f23e7d742 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,8 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME:     %[[ARG0:.+]]: index
-// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index
+// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT:   return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

``````````



https://github.com/llvm/llvm-project/pull/118202


More information about the Mlir-commits mailing list