[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)
Ian Wood
llvmlistbot at llvm.org
Thu Oct 24 13:36:03 PDT 2024
================
@@ -1982,14 +1983,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+ if (!canFoldIntoConsumerOp(castOp))
+ return failure();
+
+ ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+ SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+
+ SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+ SmallVector<Value> dynamicOutputShape;
+ auto outputIt = expandOp.getOutputShape().begin();
+
+ for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+ for (uint64_t outDim : innerReassoc) {
+ if (!ShapedType::isDynamic(newOutputShape[outDim]))
+ continue;
+
+ // If the cast's src type is dynamic, don't infer any of the
+ // corresponding expanded dimensions. `tensor.expand_shape` requires at
+ // least one of the expanded dimensions to be dynamic if the input is
+ // dynamic.
+ Value val = *outputIt;
+ ++outputIt;
+ if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+ dynamicOutputShape.push_back(val);
+ continue;
+ }
+
+ APInt cst;
+ if (matchPattern(val, m_ConstantInt(&cst))) {
+ newOutputShape[outDim] = cst.getSExtValue();
+ } else {
+ dynamicOutputShape.push_back(val);
+ }
+ }
+ }
+
+ // Couldn't match any values, nothing to change
+ if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+ return failure();
+
+ // Calculate the input shape from the output
+ SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+ for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
----------------
IanWood1 wrote:
Good catch. I need to get out of the habit of using post-increment. I'll change the loop to use `llvm::seq`.
https://github.com/llvm/llvm-project/pull/112265
More information about the Mlir-commits mailing list