[llvm] [mlir] [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>` (PR #174119)

Jakub Kuderski via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 31 14:44:38 PST 2025


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/174119

Support specifying the default value of a `TypeSwitch` whose result type is constructible from `LogicalResult` (e.g., `FailureOr<T>`) without having to write a lambda: `.Default(failure())`.

>From bc9a793ca596da38745dd74ff4c7b6dffc0f5675 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Dec 2025 17:42:28 -0500
Subject: [PATCH] [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>`

Support specifying the default value without having to write a lambda,
e.g.: `.Default(failure());`.
---
 llvm/include/llvm/ADT/TypeSwitch.h            | 11 ++++++++++
 llvm/unittests/ADT/TypeSwitchTest.cpp         | 22 +++++++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../Transforms/DataLayoutPropagation.cpp      |  4 ++--
 .../Linalg/Transforms/Vectorization.cpp       |  4 ++--
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        |  2 +-
 6 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 50ca1d5a6b5b6..4748ee29ad1c2 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 namespace llvm {
@@ -135,6 +136,16 @@ class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
     return Default(ResultT(std::nullopt));
   }
 
+  /// Default for result types constructible from `LogicalResult` (e.g.,
+  /// `FailureOr<T>`); `LogicalResult` itself uses `Default(ResultT)`.
+  template <typename ArgT = ResultT,
+            typename =
+                std::enable_if_t<std::is_constructible_v<ArgT, LogicalResult> &&
+                                 !std::is_same_v<ArgT, LogicalResult>>>
+  [[nodiscard]] ResultT Default(LogicalResult defaultResult) {
+    return Default(ResultT(defaultResult));
+  }
+
   /// Declare default as unreachable, making sure that all cases were handled.
   [[nodiscard]] ResultT DefaultUnreachable(
       const char *message = "Fell off the end of a type-switch") {
diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
index 0a9271785d168..3f4e2d2d3213f 100644
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -183,3 +184,24 @@ TEST(TypeSwitchTest, DefaultNullptrForPointerLike) {
   EXPECT_EQ(&foo, translate(DerivedA()).ptr);
   EXPECT_EQ(nullptr, translate(DerivedD()).ptr);
 }
+
+TEST(TypeSwitchTest, DefaultLogicalResult) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, LogicalResult>(&value)
+        .Case([](DerivedA *) { return success(); })
+        .Default(failure());
+  };
+  EXPECT_TRUE(succeeded(translate(DerivedA())));
+  EXPECT_TRUE(failed(translate(DerivedD())));
+}
+
+TEST(TypeSwitchTest, DefaultFailureOr) {
+  auto translate = [](auto value) {
+    return TypeSwitch<Base *, FailureOr<int>>(&value)
+        .Case([](DerivedA *) { return 42; })
+        .Default(failure());
+  };
+  ASSERT_TRUE(succeeded(translate(DerivedA())));
+  EXPECT_EQ(42, *translate(DerivedA()));
+  EXPECT_TRUE(failed(translate(DerivedD())));
+}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b8c1bad7c630f..b221b24dc5819 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4314,7 +4314,7 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
           .Case([&](linalg::BatchMatmulOp op) {
             return transposeBatchMatmul(rewriter, op, transposeLHS);
           })
-          .Default([&](Operation *op) { return failure(); });
+          .Default(failure());
   if (failed(maybeTransformed))
     return emitSilenceableFailure(target->getLoc()) << "not supported";
   // Handle to the new Matmul operation with transposed filters
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3bb5f8af821c0..419f6a0d3c010 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -972,7 +972,7 @@ class BubbleUpPackOpThroughReshapeOp final
         .Case([&](tensor::ExpandShapeOp op) {
           return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
         })
-        .Default([](Operation *) { return failure(); });
+        .Default(failure());
   }
 
 private:
@@ -1090,7 +1090,7 @@ class PushDownUnPackOpThroughReshapeOp final
           return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                     controlFn);
         })
-        .Default([](Operation *) { return failure(); });
+        .Default(failure());
   }
 
 private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..2b76c24334c0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2676,7 +2676,7 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
         return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
       })
-      .Default([](auto) { return failure(); });
+      .Default(failure());
 }
 
 /// Converts affine.apply Ops to arithmetic operations.
@@ -2783,7 +2783,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
                                             results);
           })
-          .Default([](auto) { return failure(); });
+          .Default(failure());
 
   if (failed(vectorizeResult)) {
     LDBG() << "Vectorization failed";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 2d4a18cc4b145..e9cd335835263 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -277,7 +277,7 @@ static LogicalResult setLoopAttr(const llvm::MDNode *node, Operation *op,
         branchOp.setLoopAnnotationAttr(attr);
         return success();
       })
-      .Default([](auto) { return failure(); });
+      .Default(failure());
 }
 
 /// Looks up all the alias scope attributes that map to the alias scope nodes



More information about the llvm-commits mailing list