[clang] 2eb6545 - [CIR] Add cir-simplify pass (#138317)

via cfe-commits cfe-commits at lists.llvm.org
Wed May 7 09:50:42 PDT 2025


Author: Morris Hafner
Date: 2025-05-07T18:50:39+02:00
New Revision: 2eb6545b3ecb567a85d9114dab69a1455c7a032c

URL: https://github.com/llvm/llvm-project/commit/2eb6545b3ecb567a85d9114dab69a1455c7a032c
DIFF: https://github.com/llvm/llvm-project/commit/2eb6545b3ecb567a85d9114dab69a1455c7a032c.diff

LOG: [CIR] Add cir-simplify pass (#138317)

This patch adds the cir-simplify pass for SelectOp and TernaryOp. It
also adds the SelectOp folder and adds the constant materializer for the
CIR dialect.

Added: 
    clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
    clang/test/CIR/Transforms/select.cir
    clang/test/CIR/Transforms/ternary-fold.cir

Modified: 
    clang/include/clang/CIR/CIRToCIRPasses.h
    clang/include/clang/CIR/Dialect/IR/CIRDialect.td
    clang/include/clang/CIR/Dialect/IR/CIROps.td
    clang/include/clang/CIR/Dialect/Passes.h
    clang/include/clang/CIR/Dialect/Passes.td
    clang/include/clang/CIR/MissingFeatures.h
    clang/lib/CIR/Dialect/IR/CIRDialect.cpp
    clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
    clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
    clang/lib/CIR/FrontendAction/CIRGenAction.cpp
    clang/lib/CIR/Lowering/CIRPasses.cpp
    clang/tools/cir-opt/cir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirCtx,
                                       clang::ASTContext &astCtx,
-                                      bool enableVerifier);
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify);
 
 } // namespace cir
 

diff  --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..52e32eedf774d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,13 @@ def CIR_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 0;
   let useDefaultTypePrinterParser = 0;
 
+  // Enable constant materialization for the CIR dialect. This generates a
+  // declaration for the cir::CIRDialect::materializeConstant function. This
+  // hook is necessary for canonicalization to properly handle attributes
+  // returned by fold methods, allowing them to be materialized as constant
+  // operations in the IR.
+  let hasConstantMaterializer = 1;
+
   let extraClassDeclaration = [{
     static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
 

diff  --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 422c89c4f9391..8d01db03cb3fa 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
       qualified(type($false_value))
     `)` `->` qualified(type($result)) attr-dict
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
 
 std::unique_ptr<Pass> createCIRCanonicalizePass();
 std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
 std::unique_ptr<Pass> createHoistAllocasPass();
 
 void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

diff  --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..de775e69f0073 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,25 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
   let dependentDialects = ["cir::CIRDialect"];
 }
 
+def CIRSimplify : Pass<"cir-simplify"> {
+  let summary = "Performs CIR simplification and code optimization";
+  let description = [{
+    The pass performs semantics-preserving code simplifications and optimizations
+    on CIR while maintaining strict program correctness. 
+    
+    Unlike the `cir-canonicalize` pass, these transformations may reduce the IR's
+    structural similarity to the original source code as a trade-off for improved
+    code quality. This can affect debugging fidelity by altering intermediate
+    representations of folded expressions, hoisted operations, and other 
+    optimized constructs.
+    
+    Example transformations include ternary expression folding and code hoisting
+    while preserving program semantics.
+  }];
+  let constructor = "mlir::createCIRSimplifyPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
 def HoistAllocas : Pass<"cir-hoist-allocas"> {
   let summary = "Hoist allocas to the entry of the function";
   let description = [{

diff  --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index eb75a073d1817..06636cd6c554c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -206,7 +206,6 @@ struct MissingFeatures {
   static bool labelOp() { return false; }
   static bool ptrDiffOp() { return false; }
   static bool ptrStrideOp() { return false; }
-  static bool selectOp() { return false; }
   static bool switchOp() { return false; }
   static bool ternaryOp() { return false; }
   static bool tryOp() { return false; }

diff  --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 6b144149b41c9..b131edaf403ed 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
   addInterfaces<CIROpAsmDialectInterface>();
 }
 
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                mlir::Attribute value,
+                                                mlir::Type type,
+                                                mlir::Location loc) {
+  return builder.create<cir::ConstantOp>(loc, type,
+                                         mlir::cast<mlir::TypedAttr>(value));
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
     result.addTypes(TypeRange{yield.getOperandTypes().front()});
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+  // A constant condition selects one of the two operands outright.
+  if (mlir::Attribute condition = adaptor.getCondition()) {
+    bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
+    return conditionValue ? getTrueValue() : getFalseValue();
+  }
+
+  // cir.select if %0 then x else x -> x
+  // Non-constant operands fold to null attributes, so demand a non-null match;
+  // otherwise two null attributes "match" and the SSA check below is skipped.
+  mlir::Attribute trueValue = adaptor.getTrueValue();
+  if (trueValue && trueValue == adaptor.getFalseValue())
+    return trueValue;
+  if (getTrueValue() == getFalseValue())
+    return getTrueValue();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShiftOp
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
   getOperation()->walk([&](Operation *op) {
     assert(!cir::MissingFeatures::switchOp());
     assert(!cir::MissingFeatures::tryOp());
-    assert(!cir::MissingFeatures::selectOp());
     assert(!cir::MissingFeatures::complexCreateOp());
     assert(!cir::MissingFeatures::complexRealOp());
     assert(!cir::MissingFeatures::complexImagOp());
     assert(!cir::MissingFeatures::callOp());
     // CastOp and UnaryOp are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
-    if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
       ops.push_back(op);
   });
 

diff  --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..b969569b0081c
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,202 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, both of the true and
+/// the false branch must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+///   %0 = ...
+///   %1 = cir.ternary (%condition, true {
+///     %2 = cir.const ...
+///     cir.yield %2
+///   }, false {
+///     cir.yield %0
+///   })
+/// into the following sequence of operations:
+///
+///   %2 = cir.const ...
+///   %1 = cir.select if %condition then %2 else %0
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+  using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TernaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return mlir::failure();
+
+    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+        !isSimpleTernaryBranch(op.getFalseRegion()))
+      return mlir::failure();
+
+    cir::YieldOp trueBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+    cir::YieldOp falseBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+    rewriter.eraseOp(trueBranchYieldOp);
+    rewriter.eraseOp(falseBranchYieldOp);
+    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+                                               falseValue);
+
+    return mlir::success();
+  }
+
+private:
+  bool isSimpleTernaryBranch(mlir::Region &region) const {
+    if (!region.hasOneBlock())
+      return false;
+
+    mlir::Block &onlyBlock = region.front();
+    mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+    // The region/block may contain at most 2 operations.
+    if (ops.size() > 2)
+      return false;
+
+    if (ops.size() == 1) {
+      // The region/block contains only a cir.yield operation.
+      return true;
+    }
+
+    // Check whether the region/block contains a cir.const followed by a
+    // cir.yield that yields the value.
+    auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
+    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+        yieldOp.getArgs()[0].getDefiningOp());
+    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+  }
+};
+
+/// Simplify select operations with boolean constants into simpler forms.
+///
+/// This pattern simplifies select operations where both true and false values
+/// are boolean constants. Two specific cases are handled:
+///
+/// 1. When selecting between true and false based on a condition,
+///    the operation simplifies to just the condition itself:
+///
+///    %0 = cir.select if %condition then true else false
+///    ->
+///    (replaced with %condition directly)
+///
+/// 2. When selecting between false and true based on a condition,
+///    the operation simplifies to the logical negation of the condition:
+///
+///    %0 = cir.select if %condition then false else true
+///    ->
+///    %0 = cir.unary not %condition
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const final {
+    mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
+    mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
+    auto trueValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
+    auto falseValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
+    if (!trueValueConstOp || !falseValueConstOp)
+      return mlir::failure();
+
+    auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+    auto falseValue =
+        mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+    if (!trueValue || !falseValue)
+      return mlir::failure();
+
+    // cir.select if %0 then #true else #false -> %0
+    if (trueValue.getValue() && !falseValue.getValue()) {
+      rewriter.replaceAllUsesWith(op, op.getCondition());
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // cir.select if %0 then #false else #true -> cir.unary not %0
+    if (!trueValue.getValue() && falseValue.getValue()) {
+      rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+                                                op.getCondition());
+      return mlir::success();
+    }
+
+    return mlir::failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
+  using CIRSimplifyBase::CIRSimplifyBase;
+
+  void runOnOperation() override;
+};
+
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    SimplifyTernary,
+    SimplifySelect
+  >(patterns.getContext());
+  // clang-format on
+}
+
+void CIRSimplifyPass::runOnOperation() {
+  // Collect rewrite patterns.
+  RewritePatternSet patterns(&getContext());
+  populateMergeCleanupPatterns(patterns);
+
+  // Collect operations to apply patterns.
+  llvm::SmallVector<Operation *, 16> ops;
+  getOperation()->walk([&](Operation *op) {
+    if (isa<TernaryOp, SelectOp>(op))
+      ops.push_back(op);
+  });
+
+  // Apply patterns.
+  if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
+    signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+  return std::make_unique<CIRSimplifyPass>();
+}

diff  --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_clang_library(MLIRCIRTransforms
   CIRCanonicalize.cpp
+  CIRSimplify.cpp
   FlattenCFG.cpp
   HoistAllocas.cpp
 

diff  --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..cc65c93f5f16b 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
   IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
   std::unique_ptr<CIRGenerator> Gen;
   const FrontendOptions &FEOptions;
+  CodeGenOptions &CGO;
 
 public:
   CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
-                 std::unique_ptr<raw_pwrite_stream> OS)
+                 CodeGenOptions &CGO, std::unique_ptr<raw_pwrite_stream> OS)
       : Action(Action), CI(CI), OutputStream(std::move(OS)),
         FS(&CI.getVirtualFileSystem()),
         Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
                                            CI.getCodeGenOpts())),
-        FEOptions(CI.getFrontendOpts()) {}
+        FEOptions(CI.getFrontendOpts()), CGO(CGO) {}
 
   void Initialize(ASTContext &Ctx) override {
     assert(!Context && "initialized multiple times");
@@ -102,7 +103,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
     if (!FEOptions.ClangIRDisablePasses) {
       // Setup and run CIR pipeline.
       if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
-                            !FEOptions.ClangIRDisableCIRVerifier)
+                            !FEOptions.ClangIRDisableCIRVerifier,
+                            CGO.OptimizationLevel > 0)
               .failed()) {
         CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
         return;
@@ -168,8 +170,8 @@ CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
   if (!Out)
     Out = getOutputStream(CI, InFile, Action);
 
-  auto Result =
-      std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+  auto Result = std::make_unique<cir::CIRGenConsumer>(
+      Action, CI, CI.getCodeGenOpts(), std::move(Out));
 
   return Result;
 }

diff  --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirContext,
                                       clang::ASTContext &astContext,
-                                      bool enableVerifier) {
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify) {
 
   llvm::TimeTraceScope scope("CIR To CIR Passes");
 
   mlir::PassManager pm(&mlirContext);
   pm.addPass(mlir::createCIRCanonicalizePass());
 
+  if (enableCIRSimplify)
+    pm.addPass(mlir::createCIRSimplifyPass());
+
   pm.enableVerifier(enableVerifier);
   (void)mlir::applyPassManagerCLOptions(pm);
   return pm.run(theModule);

diff  --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<false> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG1]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+    %0 = cir.const #cir.int<42> : !s32i
+    %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.const #cir.int<42> : !s32i
+  // CHECK-NEXT:   cir.return %[[#A]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.const #cir.bool<false> : !cir.bool
+    %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+    cir.return %2 : !cir.bool
+  }
+
+  //      CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !cir.bool
+  // CHECK-NEXT: }
+
+  cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
+    %0 = cir.const #cir.bool<false> : !cir.bool
+    %1 = cir.const #cir.bool<true> : !cir.bool
+    %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+    cir.return %2 : !cir.bool
+  }
+
+  //      CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+  // CHECK-NEXT:   %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
+  // CHECK-NEXT:   cir.return %[[#A]] : !cir.bool
+  // CHECK-NEXT: }
+}

diff  --git a/clang/test/CIR/Transforms/ternary-fold.cir b/clang/test/CIR/Transforms/ternary-fold.cir
new file mode 100644
index 0000000000000..1192a0ce29424
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary-fold.cir
@@ -0,0 +1,76 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<false> : !cir.bool
+    %1 = cir.ternary (%0, true {
+      cir.yield %arg0 : !s32i
+    }, false {
+      cir.yield %arg1 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
+    %0 = cir.ternary (%arg0, true {
+      %1 = cir.const #cir.int<42> : !s32i
+      cir.yield %1 : !s32i
+    }, false {
+      cir.yield %arg1 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.return %0 : !s32i
+  }
+
+  //      CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.const #cir.int<42> : !s32i
+  // CHECK-NEXT:   %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
+  // CHECK-NEXT:   cir.return %[[#B]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @simplify_ternary_false_const(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
+    %0 = cir.ternary (%arg0, true {
+      cir.yield %arg1 : !s32i
+    }, false {
+      %1 = cir.const #cir.int<24> : !s32i
+      cir.yield %1 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.return %0 : !s32i
+  }
+
+  //      CHECK: cir.func @simplify_ternary_false_const(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.const #cir.int<24> : !s32i
+  // CHECK-NEXT:   %[[#B:]] = cir.select if %[[ARG0]] then %[[ARG1]] else %[[#A]] : (!cir.bool, !s32i, !s32i) -> !s32i
+  // CHECK-NEXT:   cir.return %[[#B]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+    %1 = cir.ternary (%arg0, true {
+      %2 = cir.const #cir.int<42> : !s32i
+      cir.yield %2 : !s32i
+    }, false {
+      %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+      cir.yield %3 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+  // CHECK-NEXT:   %[[#B:]] = cir.ternary(%[[ARG0]], true {
+  // CHECK-NEXT:     %[[#C:]] = cir.const #cir.int<42> : !s32i
+  // CHECK-NEXT:     cir.yield %[[#C]] : !s32i
+  // CHECK-NEXT:   }, false {
+  // CHECK-NEXT:     %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
+  // CHECK-NEXT:     cir.yield %[[#D]] : !s32i
+  // CHECK-NEXT:   }) : (!cir.bool) -> !s32i
+  // CHECK-NEXT:   cir.return %[[#B]] : !s32i
+  // CHECK-NEXT: }
+}

diff  --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp
index e50fa70582966..0e20b97feced8 100644
--- a/clang/tools/cir-opt/cir-opt.cpp
+++ b/clang/tools/cir-opt/cir-opt.cpp
@@ -37,6 +37,9 @@ int main(int argc, char **argv) {
   ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
     return mlir::createCIRCanonicalizePass();
   });
+  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
+    return mlir::createCIRSimplifyPass();
+  });
 
   mlir::PassPipelineRegistration<CIRToLLVMPipelineOptions> pipeline(
       "cir-to-llvm", "",


        


More information about the cfe-commits mailing list