[Mlir-commits] [mlir] [mlir] Simplify Default cases in type switches. NFC. (PR #165767)

Jakub Kuderski llvmlistbot at llvm.org
Thu Oct 30 11:52:31 PDT 2025


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/165767

Pass default values directly to `.Default()` instead of wrapping them in lambdas where possible. `std::nullopt` and `nullptr` can now be used as default values because of https://github.com/llvm/llvm-project/pull/165724.

>From 608cf7bb0488d94edaace1674450cf01143e3f45 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Thu, 30 Oct 2025 14:50:23 -0400
Subject: [PATCH] [mlir] Simplify Default cases in type switches

Pass default values directly to `.Default()` instead of wrapping them in
lambdas where possible. `std::nullopt` and `nullptr` can now be used as
default values because of
https://github.com/llvm/llvm-project/pull/165724.
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp         | 2 +-
 mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp         | 2 +-
 mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp           | 2 +-
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp         | 2 +-
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp                      | 2 +-
 mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp       | 6 +++---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp                  | 2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp               | 2 +-
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp                    | 2 +-
 mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp            | 2 +-
 .../lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp        | 2 +-
 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp       | 2 +-
 mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp    | 2 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp                      | 4 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp                    | 2 +-
 mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp       | 2 +-
 mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp    | 2 +-
 mlir/lib/TableGen/Type.cpp                                  | 2 +-
 mlir/lib/Target/LLVMIR/DebugTranslation.cpp                 | 6 +++---
 21 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 41e333c621eda..3a307a0756d93 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
       .Case([](Float6E2M3FNType) { return 2u; })
       .Case([](Float6E3M2FNType) { return 3u; })
       .Case([](Float4E2M1FNType) { return 4u; })
-      .Default([](Type) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 /// If there is a scaled MFMA instruction for the input element types `aType`
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 247dba101cfc1..cfdcd9cc2d86d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {
                         current = op.getSource();
                         return false;
                       })
-                      .Default([](Operation *) { return false; });
+                      .Default(false);
 
     if (!skipOp) {
       break;
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 25f1e1b184d61..425594b3382f0 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
           }
           return std::nullopt;
         })
-        .Default([](auto) { return std::nullopt; });
+        .Default(std::nullopt);
   }
 
   static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..91c1aa55fdb4e 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {
           [](auto floatAttr) { return floatAttr.getValue().isZero(); })
       .Case<IntegerAttr>(
           [](auto intAttr) { return intAttr.getValue().isZero(); })
-      .Default([](auto) { return false; });
+      .Default(false);
 }
 
 static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 898d76ce8d9b5..980442efdf708 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
           .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
           .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
           .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
-          .Default([](Operation *op) { return std::nullopt; });
+          .Default(std::nullopt);
   if (!maybeKind) {
     return std::nullopt;
   }
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index d2c2138d61638..025d1acf8d6ba 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -330,7 +330,7 @@ static Value getBase(Value v) {
               v = op.getSrc();
               return true;
             })
-            .Default([](Operation *) { return false; });
+            .Default(false);
     if (!shouldContinue)
       break;
   }
@@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {
       .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
       .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
           [](auto op) { return op.getSrc(); })
-      .Default([](Operation *) { return Value(); });
+      .Default(nullptr);
 }
 
 /// Returns `true` if the given operation is known to capture the given value,
@@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
       // These operations are known not to capture.
       .Case([](memref::DeallocOp) { return false; })
       // By default, we don't know anything.
-      .Default([](Operation *) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 /// Returns `true` if the value may be captured by any of its users, i.e., if
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3eae67f4c1f98..2731069d6ef54 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
                        return structType.getBody()[memberIndex];
                      return nullptr;
                    })
-                   .Default(Type(nullptr));
+                   .Default(nullptr);
   }
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index cee943d2d86c6..7d9058c262562 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
           .Case<IntegerType, FloatType>([](auto type) {
             return type.getWidth() % 8 == 0 && type.getWidth() > 0;
           })
-          .Default([](Type) { return false; });
+          .Default(false);
   if (!canConvertType)
     return false;
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ac35eea66e9d6..ce93d18f56d39 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
           // clang-format on
           .Case<PtrLikeTypeInterface>(
               [](Type type) { return isCompatiblePtrType(type); })
-          .Default([](Type) { return false; });
+          .Default(false);
 
   if (!result)
     compatibleTypes.erase(type);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8b89244486339..b09112bcf0bb7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4499,7 +4499,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
             maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
             return true;
           })
-          .Default([&](Operation *op) { return false; });
+          .Default(false);
 
   if (!supported) {
     DiagnosedSilenceableFailure diag =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f05ffa8334d9c..6519c4f64dd05 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
                 tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
               return complex::CreateOp::create(b, t, tmp, tmp);
             })
-            .Default([](auto) { return Value(); });
+            .Default(nullptr);
     if (!fillVal)
       return failure();
     linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index 27ccf3c2ba148..6becc1f29afbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
                 ValueRange{input, collapsedKernel, iZp, kZp},
                 ValueRange{collapsedInit}, stride, dilation);
           })
-          .Default([](Operation *op) { return nullptr; });
+          .Default(nullptr);
   if (!newConv)
     return failure();
   for (auto attr : preservedAttrs)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0f317eac8fa41..cb6199f026e03 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
           [&](auto op) { return CombiningKind::MUL; })
       .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
       .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
-      .Default([&](auto op) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 /// Check whether `outputOperand` is a reduction with a single combiner
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fddf37e0b..e6850890bf8fe 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {
                      vector::MaskedStoreOp, vector::TransferReadOp,
                      vector::TransferWriteOp>(
           [](auto op) { return op.getBase(); })
-      .Default([](auto) { return Value{}; });
+      .Default(nullptr);
 }
 
 template <typename T>
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 4ebd90dbcc1d5..d380c46f7fbee 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
                              ? forOp.getInitArgs()[opResult.getResultNumber()]
                              : Value();
                 })
-                .Default([&](auto op) { return Value(); });
+                .Default(nullptr);
   }
   return false;
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0c8114d5e957e..938952ed273cd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
       llvm::TypeSwitch<Type, Type>(getType())
           .Case<spirv::CooperativeMatrixType>(
               [](auto coopType) { return coopType.getElementType(); })
-          .Default([](Type) { return nullptr; });
+          .Default(nullptr);
 
   // Case 1. -- matrices.
   if (coopElementType) {
@@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
       llvm::TypeSwitch<Type, Type>(getMatrix().getType())
           .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
               [](auto matrixType) { return matrixType.getElementType(); })
-          .Default([](Type) { return nullptr; });
+          .Default(nullptr);
 
   assert(elementType && "Unhandled type");
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f895807ea1d18..d1e275d590f78 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {
           return *elementSize * type.getNumElements();
         return std::nullopt;
       })
-      .Default(std::optional<int64_t>());
+      .Default(std::nullopt);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 88e1ab6ab1e4d..cb9b7f6ec2fd2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
   return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
       .Case<vector::ReductionOp, vector::TransposeOp>(
           [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
-      .Default([](Operation *) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 69e649d2eebe8..bc4f5a5ac7f23 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {
               return constantFoldPadOp<llvm::APInt>(
                   rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
             })
-            .Default(Value());
+            .Default(nullptr);
 
     if (!newOp)
       return rewriter.notifyMatchFailure(padTensorOp,
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index b31377e0de3e9..0f1bf83d1987b 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -56,7 +56,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
         StringRef value = init->getValue();
         return value.empty() ? std::optional<StringRef>() : value;
       })
-      .Default([](auto *) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 // Return the C++ type for this type (which may just be ::mlir::Type).
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index eeb87253e5eb8..e3bcf2749be13 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) {
             .Case<>([&](LLVM::DIGlobalVariableAttr global) {
               return translate(global);
             })
-            .Default([&](Attribute attr) { return nullptr; });
+            .Default(nullptr);
     return metadata;
   };
   return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()),
@@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) {
             .Case([&](LLVM::DILocalVariableAttr local) {
               return translate(local);
             })
-            .Case<>([&](LLVM::DIGlobalVariableAttr global) {
+            .Case([&](LLVM::DIGlobalVariableAttr global) {
               return translate(global);
             })
-            .Default([&](Attribute attr) { return nullptr; });
+            .Default(nullptr);
     return metadata;
   };
   return llvm::DIGenericSubrange::get(llvmCtx,



More information about the Mlir-commits mailing list