[llvm] [mlir] [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>` (PR #174119)
Jakub Kuderski via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 31 14:44:38 PST 2025
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/174119
Support specifying the default value directly, without having to write a lambda, e.g., `.Default(failure());` instead of `.Default([](auto) { return failure(); });`.
From bc9a793ca596da38745dd74ff4c7b6dffc0f5675 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 31 Dec 2025 17:42:28 -0500
Subject: [PATCH] [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>`
Support specifying the default value without having to write a lambda,
e.g.: `.Default(failure());`.
---
llvm/include/llvm/ADT/TypeSwitch.h | 11 ++++++++++
llvm/unittests/ADT/TypeSwitchTest.cpp | 22 +++++++++++++++++++
.../TransformOps/LinalgTransformOps.cpp | 2 +-
.../Transforms/DataLayoutPropagation.cpp | 4 ++--
.../Linalg/Transforms/Vectorization.cpp | 4 ++--
.../LLVMIR/LLVMIRToLLVMTranslation.cpp | 2 +-
6 files changed, 39 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 50ca1d5a6b5b6..4748ee29ad1c2 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
namespace llvm {
@@ -135,6 +136,16 @@ class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
return Default(ResultT(std::nullopt));
}
+ /// Default for result types constructible from `LogicalResult` (e.g.,
+ /// `FailureOr<T>`).
+ template <typename ArgT = ResultT,
+ typename =
+ std::enable_if_t<std::is_constructible_v<ArgT, LogicalResult> &&
+ !std::is_same_v<ArgT, LogicalResult>>>
+ [[nodiscard]] ResultT Default(LogicalResult result) {
+ return Default(ResultT(result));
+ }
+
/// Declare default as unreachable, making sure that all cases were handled.
[[nodiscard]] ResultT DefaultUnreachable(
const char *message = "Fell off the end of a type-switch") {
diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
index 0a9271785d168..3f4e2d2d3213f 100644
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -183,3 +184,24 @@ TEST(TypeSwitchTest, DefaultNullptrForPointerLike) {
EXPECT_EQ(&foo, translate(DerivedA()).ptr);
EXPECT_EQ(nullptr, translate(DerivedD()).ptr);
}
+
+TEST(TypeSwitchTest, DefaultLogicalResultSuccess) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, LogicalResult>(&value)
+ .Case([](DerivedA *) { return success(); })
+ .Default(failure());
+ };
+ EXPECT_TRUE(succeeded(translate(DerivedA())));
+ EXPECT_TRUE(failed(translate(DerivedD())));
+}
+
+TEST(TypeSwitchTest, DefaultFailureOr) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, FailureOr<int>>(&value)
+ .Case([](DerivedA *) { return 42; })
+ .Default(failure());
+ };
+ EXPECT_TRUE(succeeded(translate(DerivedA())));
+ EXPECT_EQ(42, *translate(DerivedA()));
+ EXPECT_TRUE(failed(translate(DerivedD())));
+}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b8c1bad7c630f..b221b24dc5819 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4314,7 +4314,7 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
.Case([&](linalg::BatchMatmulOp op) {
return transposeBatchMatmul(rewriter, op, transposeLHS);
})
- .Default([&](Operation *op) { return failure(); });
+ .Default(failure());
if (failed(maybeTransformed))
return emitSilenceableFailure(target->getLoc()) << "not supported";
// Handle to the new Matmul operation with transposed filters
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3bb5f8af821c0..419f6a0d3c010 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -972,7 +972,7 @@ class BubbleUpPackOpThroughReshapeOp final
.Case([&](tensor::ExpandShapeOp op) {
return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
})
- .Default([](Operation *) { return failure(); });
+ .Default(failure());
}
private:
@@ -1090,7 +1090,7 @@ class PushDownUnPackOpThroughReshapeOp final
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
controlFn);
})
- .Default([](Operation *) { return failure(); });
+ .Default(failure());
}
private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..2b76c24334c0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2676,7 +2676,7 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
}
/// Converts affine.apply Ops to arithmetic operations.
@@ -2783,7 +2783,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
results);
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
if (failed(vectorizeResult)) {
LDBG() << "Vectorization failed";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 2d4a18cc4b145..e9cd335835263 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -277,7 +277,7 @@ static LogicalResult setLoopAttr(const llvm::MDNode *node, Operation *op,
branchOp.setLoopAnnotationAttr(attr);
return success();
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
}
/// Looks up all the alias scope attributes that map to the alias scope nodes
More information about the llvm-commits mailing list