[Mlir-commits] [mlir] 5668631 - [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>` (#174119)
llvmlistbot at llvm.org
Fri Jan 2 04:17:50 PST 2026
Author: Jakub Kuderski
Date: 2026-01-02T07:17:45-05:00
New Revision: 56686315f28d7386f01ddefae45849bb80b78b5c
URL: https://github.com/llvm/llvm-project/commit/56686315f28d7386f01ddefae45849bb80b78b5c
DIFF: https://github.com/llvm/llvm-project/commit/56686315f28d7386f01ddefae45849bb80b78b5c.diff
LOG: [ADT] Allow `TypeSwitch::Default` for `FailureOr<T>` (#174119)
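Support specifying the default value directly instead of wrapping it in a
lambda when the result type is constructible from `LogicalResult` (e.g.,
`FailureOr<T>`), for example: `.Default(failure());`.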
Added:
Modified:
llvm/include/llvm/ADT/TypeSwitch.h
llvm/unittests/ADT/TypeSwitchTest.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
index 50ca1d5a6b5b6..4748ee29ad1c2 100644
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
namespace llvm {
@@ -135,6 +136,16 @@ class TypeSwitch : public detail::TypeSwitchBase<TypeSwitch<T, ResultT>, T> {
return Default(ResultT(std::nullopt));
}
+ /// Default for result types constructible from `LogicalResult` (e.g.,
+ /// `FailureOr<T>`).
+ template <typename ArgT = ResultT,
+ typename =
+ std::enable_if_t<std::is_constructible_v<ArgT, LogicalResult> &&
+ !std::is_same_v<ArgT, LogicalResult>>>
+ [[nodiscard]] ResultT Default(LogicalResult result) {
+ return Default(ResultT(result));
+ }
+
/// Declare default as unreachable, making sure that all cases were handled.
[[nodiscard]] ResultT DefaultUnreachable(
const char *message = "Fell off the end of a type-switch") {
diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
index 0a9271785d168..3f4e2d2d3213f 100644
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -183,3 +184,24 @@ TEST(TypeSwitchTest, DefaultNullptrForPointerLike) {
EXPECT_EQ(&foo, translate(DerivedA()).ptr);
EXPECT_EQ(nullptr, translate(DerivedD()).ptr);
}
+
+TEST(TypeSwitchTest, DefaultLogicalResultSuccess) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, LogicalResult>(&value)
+ .Case([](DerivedA *) { return success(); })
+ .Default(failure());
+ };
+ EXPECT_TRUE(succeeded(translate(DerivedA())));
+ EXPECT_TRUE(failed(translate(DerivedD())));
+}
+
+TEST(TypeSwitchTest, DefaultFailureOr) {
+ auto translate = [](auto value) {
+ return TypeSwitch<Base *, FailureOr<int>>(&value)
+ .Case([](DerivedA *) { return 42; })
+ .Default(failure());
+ };
+ EXPECT_TRUE(succeeded(translate(DerivedA())));
+ EXPECT_EQ(42, *translate(DerivedA()));
+ EXPECT_TRUE(failed(translate(DerivedD())));
+}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b8c1bad7c630f..b221b24dc5819 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4314,7 +4314,7 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
.Case([&](linalg::BatchMatmulOp op) {
return transposeBatchMatmul(rewriter, op, transposeLHS);
})
- .Default([&](Operation *op) { return failure(); });
+ .Default(failure());
if (failed(maybeTransformed))
return emitSilenceableFailure(target->getLoc()) << "not supported";
// Handle to the new Matmul operation with transposed filters
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 3bb5f8af821c0..419f6a0d3c010 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -972,7 +972,7 @@ class BubbleUpPackOpThroughReshapeOp final
.Case([&](tensor::ExpandShapeOp op) {
return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
})
- .Default([](Operation *) { return failure(); });
+ .Default(failure());
}
private:
@@ -1090,7 +1090,7 @@ class PushDownUnPackOpThroughReshapeOp final
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
controlFn);
})
- .Default([](Operation *) { return failure(); });
+ .Default(failure());
}
private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..2b76c24334c0a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2676,7 +2676,7 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
}
/// Converts affine.apply Ops to arithmetic operations.
@@ -2783,7 +2783,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
results);
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
if (failed(vectorizeResult)) {
LDBG() << "Vectorization failed";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 2d4a18cc4b145..e9cd335835263 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -277,7 +277,7 @@ static LogicalResult setLoopAttr(const llvm::MDNode *node, Operation *op,
branchOp.setLoopAnnotationAttr(attr);
return success();
})
- .Default([](auto) { return failure(); });
+ .Default(failure());
}
/// Looks up all the alias scope attributes that map to the alias scope nodes
More information about the Mlir-commits mailing list