[Mlir-commits] [mlir] [mlir][Tensor] Rework `ReifyRankedShapedTypeInterface` implementation for `tensor.expand_shape` op. (PR #113501)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 23 15:45:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

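The `tensor.expand_shape` op carries its output shape directly, so the `ReifyRankedShapedTypeOpInterface` implementation can return that shape as-is instead of recomputing it from the source shape. This change also adds a method, `getMixedOutputShape()`, that returns the output shape as a `SmallVector<OpFoldResult>`.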

---
Full diff: https://github.com/llvm/llvm-project/pull/113501.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3) 
- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+1) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+19-8) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+8-2) 
- (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (-8) 
- (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5) 
- (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+3-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..8203b9c0fab437 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1160,6 +1160,9 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
 
+    // Return output shape as mixed static/dynamic shapes.
+    SmallVector<OpFoldResult> getMixedOutputShape();
+
     // Infer the output shape for a tensor.expand_shape when it is possible
     // to do so.
     static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 4d7aa1ae17fdb1..9f0e01f1d8ca00 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,6 +125,9 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// Return a vector of OpFoldResults with the same size a staticValues, but
 /// all elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context);
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
                                          ValueRange dynamicValues, Builder &b);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 792e7229183064..416aac7d64aad5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebc458170337d6 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -144,15 +144,14 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
                    builder, loc, src, dstStaticShape, reassocation);
 }
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
+struct ReifyCollapseShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+          ReifyCollapseShapeOp, CollapseShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
+    auto reshapeOp = cast<tensor::CollapseShapeOp>(op);
     reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
         b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
         reshapeOp.getReassociationMaps()));
@@ -162,6 +161,20 @@ struct ReifyExpandOrCollapseShapeOp
 
 namespace {
 
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifyResultShapes) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> resultShapes =
+        expandShapeOp.getMixedOutputShape();
+    reifyResultShapes.emplace_back(std::move(resultShapes));
+    return success();
+  }
+};
+
 struct ReifyPadOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyPadOp,
                                                              PadOp> {
@@ -202,10 +215,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..f1f33bd940f7d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1675,6 +1675,10 @@ ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
   return *outputShape;
 }
 
+SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
+  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
+}
+
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value src,
                           ArrayRef<ReassociationIndices> reassociation,
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 3eb6215a7a0b9b..f1166269f0a400 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -177,7 +177,8 @@ bool isEqualConstantIntOrValueArray(ArrayRef<OpFoldResult> ofrs1,
 /// elements for which ShapedType::isDynamic is true, will be replaced by
 /// dynamicValues.
 SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
-                                         ValueRange dynamicValues, Builder &b) {
+                                         ValueRange dynamicValues,
+                                         MLIRContext *context) {
   SmallVector<OpFoldResult> res;
   res.reserve(staticValues.size());
   unsigned numDynamic = 0;
@@ -186,10 +187,15 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
     int64_t value = staticValues[idx];
     res.push_back(ShapedType::isDynamic(value)
                       ? OpFoldResult{dynamicValues[numDynamic++]}
-                      : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])});
+                      : OpFoldResult{IntegerAttr::get(
+                            IntegerType::get(context, 64), staticValues[idx])});
   }
   return res;
 }
+SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
+                                         ValueRange dynamicValues, Builder &b) {
+  return getMixedValues(staticValues, dynamicValues, b.getContext());
+}
 
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 8cc4206dae6edf..c7f5fcb1d21fc8 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -48,14 +48,6 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
     assert(shapedType.getRank() ==
                static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
            "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
-      // reifyResultShapes must return:
-      // * Attribute for static dimensions
-      // * Value for dynamic dimensions
-      assert(shapedType.isDynamicDim(dim) ==
-                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
-             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
-    }
     ++resultIdx;
   }
   // Assert that every shaped value result was reified.
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..3bc1f56d816d73 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @dim_reshape_expansion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME:   %[[ARG1:.+]]: index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..850bbcee340203 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,9 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME:     %[[ARG0:.+]]: index
-// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT:   return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/113501


More information about the Mlir-commits mailing list