[llvm-branch-commits] [mlir] [mlir][Transforms] Add 1:N `matchAndRewrite` overload (PR #116470)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Nov 18 04:09:22 PST 2024
================
@@ -1376,14 +1423,36 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
+ UnrealizedConversionCastOp targetCastOp;
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat), loc,
/*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
+ /*originalType=*/originalType, converter, &targetCastOp);
+ if (targetCastOp)
+ nTo1TempMaterializations.insert(targetCastOp);
mapping.map(argMat, targetMat);
}
}
+SmallVector<Value>
+ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
+ // Unpack unrealized_conversion_cast ops that were inserted as a N:1
+ // workaround.
+ auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (!castOp)
+ return {value};
+ if (!nTo1TempMaterializations.contains(castOp))
+ return {value};
+ assert(castOp->getNumResults() == 1 && "expected single result");
+
+ SmallVector<Value> result;
----------------
matthias-springer wrote:
Note: The test case (output IR) becomes much simpler because the patterns no longer call `materializeTargetConversion`. Instead, the driver handles all materializations: it produces `unrealized_conversion_cast` ops and attempts to fold them before calling the materialization functions. (Once a `get_tuple` op has been built, the casts can no longer be folded.)
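To illustrate the fold-before-materialize ordering described above, here is a toy model in plain C++. It is not MLIR code. `Value`, `Cast`, `foldCast` and `resolveCast` are hypothetical names. The folding rule shown (an identity cast, or a round trip A -> B -> A, folds to the original value) is a simplified assumption about how chains of `unrealized_conversion_cast` ops collapse. A counter stands in for the user-provided materialization callback. Once a value has been materialized it becomes an opaque op, like a `get_tuple`, and later casts on top of it can no longer fold.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR (not the MLIR API): a value has a type name and,
// optionally, the 1:1 cast op that defines it.
struct Cast;
struct Value {
  std::string type;
  std::shared_ptr<Cast> definingCast;  // null for args or opaque ops
};
using ValuePtr = std::shared_ptr<Value>;

// Toy stand-in for a driver-inserted 1:1 unrealized_conversion_cast.
struct Cast {
  ValuePtr input;
};

ValuePtr makeCast(ValuePtr input, std::string resultType) {
  auto cast = std::make_shared<Cast>(Cast{std::move(input)});
  return std::make_shared<Value>(Value{std::move(resultType), cast});
}

// Assumed folding rule: cast(x : T -> T) folds to x, and
// cast(cast(x : A -> B) : B -> A) folds to x. Otherwise nothing folds.
ValuePtr foldCast(const ValuePtr &v) {
  if (!v->definingCast) return v;
  ValuePtr in = v->definingCast->input;
  if (in->type == v->type) return in;  // identity cast
  if (in->definingCast && in->definingCast->input->type == v->type)
    return in->definingCast->input;  // round trip A -> B -> A
  return v;
}

// Models the driver: try folding first; only an unresolved cast reaches the
// (hypothetical) materialization callback, counted in `materializationCalls`.
ValuePtr resolveCast(const ValuePtr &v, int &materializationCalls) {
  ValuePtr folded = foldCast(v);
  if (folded != v) return folded;  // folded away, no callback needed
  if (!v->definingCast) return v;  // not a cast, nothing to resolve
  ++materializationCalls;
  // The materialized result is an opaque op: it no longer exposes a cast.
  return std::make_shared<Value>(Value{v->type, nullptr});
}
```

With this model, a round trip such as `makeCast(makeCast(x, "legal"), x->type)` resolves to `x` without invoking the callback, while a cast that does not form a round trip is materialized exactly once.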
https://github.com/llvm/llvm-project/pull/116470