[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)

Ian Wood llvmlistbot at llvm.org
Tue Oct 22 09:33:48 PDT 2024


================
@@ -1982,14 +1983,91 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    const ArrayRef<int64_t> castSrcShape =
+        castOp.getSource().getType().getShape();
+    const SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (const uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
----------------
IanWood1 wrote:

From the `partial_sink_expand_of_cast` test case:

```mlir
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
```

`tensor.expand_shape`'s src type cannot become fully static, because the op requires an input dim to be dynamic whenever one or more of its corresponding output dims is dynamic. As a result, the input cast becomes `tensor<10x10xf32> to tensor<?x10xf32>` instead of being removed entirely. Alternatively, I could bail whenever not all SSA values can be matched (in groups where the input dim could be made static). That way, the input shape would be the same as the `tensor.cast` source shape, at the cost of not being able to propagate any of the static dim info.

https://github.com/llvm/llvm-project/pull/112265

