[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 21 22:53:45 PDT 2024


================
@@ -1982,14 +1983,91 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    const ArrayRef<int64_t> castSrcShape =
+        castOp.getSource().getType().getShape();
+    const SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (const uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
----------------
MaheshRavishankar wrote:

Do you need this? Isn't the input shape the same as the source of the cast operation?

https://github.com/llvm/llvm-project/pull/112265


More information about the Mlir-commits mailing list