[Mlir-commits] [mlir] [mlir][SPRIV][NFC] Avoid rollback in `TypeCastingOpPattern` (PR #136284)
llvmlistbot at llvm.org
Fri Apr 18 01:35:24 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This pattern used to create the replacement op first and then attach the converted rounding mode attribute to it. When the rounding mode conversion failed, the pattern aborted after the op had already been created, which triggered a rollback.
This commit inverts the logic: the converted rounding mode is computed first, so the pattern can fail before making any changes and nothing has to be rolled back.
Note: This is in preparation for the One-Shot Dialect Conversion refactoring.
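The reordering described above can be sketched outside of MLIR as a "validate first, mutate second" function. This is only an illustration under stated assumptions: `FakeRewriter`, `SrcRounding`, `DstRounding`, `convertRounding`, and `rewriteCast` are hypothetical stand-ins and not part of the MLIR API, and the choice of which mode is unsupported is arbitrary.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the source and target rounding modes.
enum class SrcRounding { ToNearestEven, TowardZero, ToNearestAway };
enum class DstRounding { RTE, RTZ };

// Hypothetical rewriter: every entry in createdOps is a side effect that
// would have to be rolled back if the pattern failed afterwards.
struct FakeRewriter {
  std::vector<std::string> createdOps;
};

// Returns std::nullopt for modes that have no counterpart (assumed here
// to be ToNearestAway, purely for illustration).
std::optional<DstRounding> convertRounding(SrcRounding mode) {
  switch (mode) {
  case SrcRounding::ToNearestEven:
    return DstRounding::RTE;
  case SrcRounding::TowardZero:
    return DstRounding::RTZ;
  default:
    return std::nullopt;
  }
}

// Step 1 performs every check that can fail. Step 2 mutates the rewriter
// only after all checks have passed, so a failure leaves no trace.
bool rewriteCast(FakeRewriter &rewriter, std::optional<SrcRounding> mode) {
  // Step 1: compute the new rounding mode (if any).
  std::optional<DstRounding> rm;
  if (mode) {
    if (!(rm = convertRounding(*mode)))
      return false; // Match failure: nothing has been created yet.
  }

  // Step 2: create the replacement op and attach the attribute (if any).
  std::string op = "cast";
  if (rm)
    op += (*rm == DstRounding::RTE) ? " {RTE}" : " {RTZ}";
  rewriter.createdOps.push_back(op);
  return true;
}
```

In this sketch, calling `rewriteCast` with an unsupported mode returns `false` and leaves `createdOps` untouched, which mirrors the pattern returning a match failure before any IR has been modified.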
---
Full diff: https://github.com/llvm/llvm-project/pull/136284.diff
1 Files Affected:
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+12-8)
``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..434d7df853a5e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -847,24 +847,28 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
// Then we can just erase this operation by forwarding its operand.
rewriter.replaceOp(op, adaptor.getOperands().front());
} else {
- auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
- op, dstType, adaptor.getOperands());
+ // Compute new rounding mode (if any).
+ std::optional<spirv::FPRoundingMode> rm = std::nullopt;
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
if (arith::RoundingModeAttr roundingMode =
roundingModeOp.getRoundingModeAttr()) {
- if (auto rm =
- convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
- newOp->setAttr(
- getDecorationString(spirv::Decoration::FPRoundingMode),
- spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
- } else {
+ if (!(rm =
+ convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
}
}
}
+ // Create replacement op and attach rounding mode attribute (if any).
+ auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+ op, dstType, adaptor.getOperands());
+ if (rm) {
+ newOp->setAttr(
+ getDecorationString(spirv::Decoration::FPRoundingMode),
+ spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+ }
}
return success();
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/136284