[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)

Ian Wood llvmlistbot at llvm.org
Thu Oct 24 13:36:03 PDT 2024


================
@@ -1982,14 +1983,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
----------------
IanWood1 wrote:

Good catch. I need to get out of the habit of using post-increment. I'll change it to `llvm::seq`.

https://github.com/llvm/llvm-project/pull/112265


More information about the Mlir-commits mailing list