[Mlir-commits] [mlir] [mlir][SPIRV][NFC] Avoid rollback in `TypeCastingOpPattern` (PR #136284)

Matthias Springer llvmlistbot at llvm.org
Fri Apr 18 01:34:47 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136284

This pattern used to create the replacement op first and then attach the converted rounding mode attribute. When the rounding mode conversion failed, the pattern aborted and a rollback was triggered.

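This commit inverts the order: the rounding mode is converted first, and the replacement op is created only after the conversion has succeeded. A match failure therefore leaves the IR untouched, and no changes have to be rolled back.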
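A minimal, self-contained sketch of this validate-first ordering, outside of MLIR. `Rewriter`, `rewriteCast`, `convertRoundingMode`, and both rounding-mode enums are hypothetical stand-ins, not MLIR APIs; the rewriter only journals mutations, to show that a failing match leaves nothing behind to undo. The mapping assumes that the arith "to nearest, ties away" mode has no SPIR-V counterpart.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for the arith and SPIR-V rounding-mode enums.
enum class ArithRoundingMode { ToNearestEven, Downward, Upward, TowardZero, ToNearestAway };
enum class SpirvRoundingMode { RTE, RTN, RTP, RTZ };

// Hypothetical fallible conversion. Assumption: ToNearestAway has no SPIR-V
// equivalent, so it yields std::nullopt.
std::optional<SpirvRoundingMode> convertRoundingMode(ArithRoundingMode m) {
  switch (m) {
    case ArithRoundingMode::ToNearestEven: return SpirvRoundingMode::RTE;
    case ArithRoundingMode::Downward:      return SpirvRoundingMode::RTN;
    case ArithRoundingMode::Upward:        return SpirvRoundingMode::RTP;
    case ArithRoundingMode::TowardZero:    return SpirvRoundingMode::RTZ;
    default:                               return std::nullopt;
  }
}

// Hypothetical rewriter that journals every IR mutation. If a pattern failed
// after mutating, the driver would have to undo these entries (a rollback).
struct Rewriter {
  std::vector<std::string> journal;
  void createOp(const std::string &name) { journal.push_back("create " + name); }
  void setAttr(const std::string &name) { journal.push_back("setAttr " + name); }
};

// Validate first: every fallible step runs before the first mutation.
bool rewriteCast(Rewriter &rewriter, std::optional<ArithRoundingMode> mode) {
  std::optional<SpirvRoundingMode> rm;
  if (mode) {
    rm = convertRoundingMode(*mode);
    if (!rm)
      return false;  // Match failure: the journal is still empty.
  }
  // Only now is the IR mutated; nothing after this point can fail.
  rewriter.createOp("spirv.FConvert");  // Illustrative op name.
  if (rm)
    rewriter.setAttr("fp_rounding_mode");
  return true;
}
```

For example, `rewriteCast(r, ArithRoundingMode::ToNearestAway)` returns `false` and leaves `r.journal` empty, whereas the old ordering would already have recorded a `create` entry that then needed to be undone.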

Note: This is in preparation for the One-Shot Dialect Conversion refactoring.

From 3b82d8e9998ce43a2b5fdeff4e1c1d3a12ba7ba5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 18 Apr 2025 10:30:49 +0200
Subject: [PATCH] [mlir][SPIRV][NFC] Avoid rollback in `TypeCastingOpPattern`

This pattern used to create the replacement op first and then attach the converted rounding mode attribute. When the rounding mode conversion failed, a rollback was triggered.

This commit inverts the order: the rounding mode is converted first, and the replacement op is created only after the conversion has succeeded, so no changes have to be rolled back.

Note: This is in preparation for the One-Shot Dialect Conversion refactoring.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..434d7df853a5e 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -847,24 +847,28 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(op, adaptor.getOperands().front());
     } else {
-      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
-          op, dstType, adaptor.getOperands());
+      // Compute new rounding mode (if any).
+      std::optional<spirv::FPRoundingMode> rm = std::nullopt;
       if (auto roundingModeOp =
               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
         if (arith::RoundingModeAttr roundingMode =
                 roundingModeOp.getRoundingModeAttr()) {
-          if (auto rm =
-                  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
-            newOp->setAttr(
-                getDecorationString(spirv::Decoration::FPRoundingMode),
-                spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
-          } else {
+          if (!(rm =
+                    convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
             return rewriter.notifyMatchFailure(
                 op->getLoc(),
                 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
           }
         }
       }
+      // Create replacement op and attach rounding mode attribute (if any).
+      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          op, dstType, adaptor.getOperands());
+      if (rm) {
+        newOp->setAttr(
+            getDecorationString(spirv::Decoration::FPRoundingMode),
+            spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+      }
     }
     return success();
   }
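The new code stores the conversion result and tests it in a single condition, `if (!(rm = convertArithRoundingModeToSPIRV(...)))`. The sketch below isolates that idiom with a hypothetical `halve` helper standing in for the conversion: assigning to a `std::optional` yields a reference to it, and `!` contextually converts that reference to `bool`, so the branch is taken exactly when the conversion produced `std::nullopt`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical fallible conversion: only even inputs succeed.
std::optional<int> halve(int x) {
  if (x % 2 != 0)
    return std::nullopt;
  return x / 2;
}

// Same shape as the diff: assign, then test the optional for a value.
bool tryHalve(int x, std::optional<int> &out) {
  if (!(out = halve(x)))
    return false;  // `out` now holds std::nullopt.
  return true;     // `out` holds the converted value.
}
```

Writing the assignment and the check as two statements (`rm = convert(...); if (!rm) ...`) is equivalent and may read more easily.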