[Mlir-commits] [mlir] [mlir][SPRIV][NFC] Avoid rollback in `TypeCastingOpPattern` (PR #136284)
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 18 01:34:47 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136284
This pattern used to create the replacement op first and then attach the converted rounding mode attribute to it. When the rounding mode conversion failed, the pattern aborted and a rollback was triggered.
This commit inverts the logic: the converted rounding mode is computed first, so the pattern can fail before modifying the IR and no changes have to be rolled back.
Note: This is in preparation for the One-Shot Dialect Conversion refactoring.
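The reordering can be illustrated outside MLIR. The following is a minimal, self-contained C++17 sketch, not MLIR code: `ToyOp`, `ArithRounding`, `SpirvRounding`, `convertRounding`, and both rewrite functions are hypothetical stand-ins for the op, the two rounding-mode enums, `convertArithRoundingModeToSPIRV`, and the pattern. It contrasts the old order (replace, then convert) with the new one (convert, then replace).

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins; these are NOT the MLIR arith/spirv enums.
enum class ArithRounding { ToNearestEven, Downward, ToNearestAway };
enum class SpirvRounding { RTE, RTN };

// Hypothetical conversion: ToNearestAway has no counterpart in this toy.
std::optional<SpirvRounding> convertRounding(ArithRounding mode) {
  switch (mode) {
  case ArithRounding::ToNearestEven:
    return SpirvRounding::RTE;
  case ArithRounding::Downward:
    return SpirvRounding::RTN;
  default:
    return std::nullopt;
  }
}

// Toy op: a name plus an optional rounding-mode "attribute".
struct ToyOp {
  std::string name;
  std::optional<SpirvRounding> rounding;
};

// Old order: replace first, then try to convert the rounding mode.
// If the conversion fails, ops[idx] has already been modified, so a
// real framework would have to roll the replacement back.
bool rewriteReplaceFirst(std::vector<ToyOp> &ops, std::size_t idx,
                         std::optional<ArithRounding> mode) {
  ops[idx] = ToyOp{"spirv.Convert", std::nullopt};
  if (mode) {
    if (auto rm = convertRounding(*mode))
      ops[idx].rounding = rm;
    else
      return false; // IR already changed: needs a rollback.
  }
  return true;
}

// New order: compute the fallible rounding mode first, then replace.
bool rewriteComputeFirst(std::vector<ToyOp> &ops, std::size_t idx,
                         std::optional<ArithRounding> mode) {
  std::optional<SpirvRounding> rm = std::nullopt;
  if (mode && !(rm = convertRounding(*mode)))
    return false; // Nothing was modified, so nothing to roll back.
  ops[idx] = ToyOp{"spirv.Convert", rm};
  return true;
}
```

With an unsupported mode, `rewriteReplaceFirst` returns `false` but leaves the op already replaced, while `rewriteComputeFirst` returns `false` with the op untouched.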
From 3b82d8e9998ce43a2b5fdeff4e1c1d3a12ba7ba5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 18 Apr 2025 10:30:49 +0200
Subject: [PATCH] [mlir][SPRIV][NFC] Avoid rollback in `TypeCastingOpPattern`
This pattern used to create the replacement op first and then attach the converted rounding mode attribute to it. When the rounding mode conversion failed, a rollback was triggered.
This commit inverts the logic: the converted rounding mode is computed first, so that no changes have to be rolled back.
Note: This is in preparation for the One-Shot Dialect Conversion refactoring.
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 20 +++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..434d7df853a5e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -847,24 +847,28 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(op, adaptor.getOperands().front());
     } else {
-      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
-          op, dstType, adaptor.getOperands());
+      // Compute new rounding mode (if any).
+      std::optional<spirv::FPRoundingMode> rm = std::nullopt;
       if (auto roundingModeOp =
               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
         if (arith::RoundingModeAttr roundingMode =
                 roundingModeOp.getRoundingModeAttr()) {
-          if (auto rm =
-                  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
-            newOp->setAttr(
-                getDecorationString(spirv::Decoration::FPRoundingMode),
-                spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
-          } else {
+          if (!(rm =
+                    convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
             return rewriter.notifyMatchFailure(
                 op->getLoc(),
                 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
           }
         }
       }
+      // Create replacement op and attach rounding mode attribute (if any).
+      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          op, dstType, adaptor.getOperands());
+      if (rm) {
+        newOp->setAttr(
+            getDecorationString(spirv::Decoration::FPRoundingMode),
+            spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+      }
     }
     return success();
   }
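The new code declares `rm` before the nested `if`s and assigns it inside the failure check. This works because the assignment operator of `std::optional` returns a reference to the optional, which `!` contextually converts to `bool`. The outer parentheses are required: `!rm = ...` would parse as `(!rm) = ...` and not compile. If the op carries no rounding mode attribute, `rm` stays `std::nullopt` and the later `if (rm)` attaches nothing. The sketch below, using a hypothetical `tryConvert` helper in place of `convertArithRoundingModeToSPIRV`, checks that the inline form behaves like an explicit two-step assignment and test.

```cpp
#include <optional>

// Hypothetical fallible conversion: succeeds only for non-negative inputs.
std::optional<int> tryConvert(int x) {
  if (x < 0)
    return std::nullopt;
  return x * 2;
}

// Form used in the patch: assign and test in a single condition.
// `rm = ...` yields `std::optional<int>&`; `!` contextually converts it
// to bool, so the branch is taken when the conversion produced no value.
bool convertInline(int x, std::optional<int> &rm) {
  if (!(rm = tryConvert(x)))
    return false;
  return true;
}

// Equivalent two-step form: assign first, then check.
bool convertTwoStep(int x, std::optional<int> &rm) {
  rm = tryConvert(x);
  if (!rm.has_value())
    return false;
  return true;
}
```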