[Mlir-commits] [mlir] [mlir][linalg][elementwise] Fold transpose into new elementwise (PR #130207)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 6 17:12:07 PST 2025
================
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+namespace {
+/// Rewrite pattern that folds `TransposeOp` producers into an `ElementwiseOp`.
+///
+/// Each input that is both read through an identity indexing map and defined
+/// by a `TransposeOp` is replaced by the transpose's own input, paired with the
+/// indexing map the transpose uses for that input. Every other input keeps its
+/// original value and map. The last indexing map of the original op is
+/// appended unchanged (presumably the map of the single init/output operand).
+/// The pattern reports failure when no input could be folded.
+struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
+  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElementwiseOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> foldedInputs;
+    SmallVector<AffineMap> foldedMaps;
+    bool foldedAny = false;
+
+    for (OpOperand *input : op.getDpsInputOperands()) {
+      AffineMap inputMap = op.getMatchingIndexingMap(input);
+      auto producer = input->get().getDefiningOp<TransposeOp>();
+
+      if (inputMap.isIdentity() && producer) {
+        // Bypass the transpose: consume its source directly and take over the
+        // map it applies to that source (its inverse permutation map).
+        foldedInputs.push_back(producer.getInput());
+        foldedMaps.push_back(
+            producer.getMatchingIndexingMap(producer.getDpsInputOperand(0)));
+        foldedAny = true;
+      } else {
+        // Not foldable: carry the operand and its map over untouched.
+        foldedInputs.push_back(input->get());
+        foldedMaps.push_back(inputMap);
+      }
+    }
+
+    if (!foldedAny)
+      return failure();
+
+    // Only input maps were collected above; re-attach the op's final map.
+    foldedMaps.push_back(op.getIndexingMapsArray().back());
+
+    rewriter.replaceOpWithNewOp<ElementwiseOp>(
+        op, foldedInputs, op.getDpsInits()[0], op.getKindAttr(),
+        rewriter.getAffineMapArrayAttr(foldedMaps));
+    return success();
+  }
+};
+} // namespace
+void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
----------------
MaheshRavishankar wrote:
I don't think this should be part of canonicalization. Maybe this requires a `populate*` method.
https://github.com/llvm/llvm-project/pull/130207
More information about the Mlir-commits
mailing list