[clang] 2eb6545 - [CIR] Add cir-simplify pass (#138317)
via cfe-commits
cfe-commits at lists.llvm.org
Wed May 7 09:50:42 PDT 2025
Author: Morris Hafner
Date: 2025-05-07T18:50:39+02:00
New Revision: 2eb6545b3ecb567a85d9114dab69a1455c7a032c
URL: https://github.com/llvm/llvm-project/commit/2eb6545b3ecb567a85d9114dab69a1455c7a032c
DIFF: https://github.com/llvm/llvm-project/commit/2eb6545b3ecb567a85d9114dab69a1455c7a032c.diff
LOG: [CIR] Add cir-simplify pass (#138317)
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It
also adds the SelectOp folder and adds the constant materializer for the
CIR dialect.
Added:
clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
clang/test/CIR/Transforms/select.cir
clang/test/CIR/Transforms/ternary-fold.cir
Modified:
clang/include/clang/CIR/CIRToCIRPasses.h
clang/include/clang/CIR/Dialect/IR/CIRDialect.td
clang/include/clang/CIR/Dialect/IR/CIROps.td
clang/include/clang/CIR/Dialect/Passes.h
clang/include/clang/CIR/Dialect/Passes.td
clang/include/clang/CIR/MissingFeatures.h
clang/lib/CIR/Dialect/IR/CIRDialect.cpp
clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
clang/lib/CIR/FrontendAction/CIRGenAction.cpp
clang/lib/CIR/Lowering/CIRPasses.cpp
clang/tools/cir-opt/cir-opt.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
- bool enableVerifier);
+ bool enableVerifier,
+ bool enableCIRSimplify);
} // namespace cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..52e32eedf774d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,13 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
+ // Enable constant materialization for the CIR dialect. This generates a
+ // declaration for the cir::CIRDialect::materializeConstant function. This
+ // hook is necessary for canonicalization to properly handle attributes
+ // returned by fold methods, allowing them to be materialized as constant
+ // operations in the IR.
+ let hasConstantMaterializer = 1;
+
let extraClassDeclaration = [{
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 422c89c4f9391..8d01db03cb3fa 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..de775e69f0073 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,25 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def CIRSimplify : Pass<"cir-simplify"> {
+  let summary = "Performs CIR simplification and code optimization";
+  let description = [{
+    The pass performs semantics-preserving simplifications and optimizations
+    on CIR.
+
+    Unlike the `cir-canonicalize` pass, these transformations may reduce the
+    IR's structural similarity to the original source code as a trade-off for
+    improved code quality. This can affect debugging fidelity by altering
+    intermediate representations of folded expressions, hoisted operations,
+    and other optimized constructs.
+
+    The transformations currently performed are:
+
+    - `cir.ternary` operations whose branches each either yield a value
+      directly or yield a `cir.const` defined in the branch are rewritten
+      into a `cir.select`, hoisting those constants out of the branches.
+    - `cir.select` operations choosing between the boolean constants true and
+      false are replaced by the condition itself or by its logical negation.
+  }];
+  let constructor = "mlir::createCIRSimplifyPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
def HoistAllocas : Pass<"cir-hoist-allocas"> {
let summary = "Hoist allocas to the entry of the function";
let description = [{
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index eb75a073d1817..06636cd6c554c 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -206,7 +206,6 @@ struct MissingFeatures {
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
static bool ptrStrideOp() { return false; }
- static bool selectOp() { return false; }
static bool switchOp() { return false; }
static bool ternaryOp() { return false; }
static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 6b144149b41c9..b131edaf403ed 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>();
}
+/// Materialize a cir.const operation for an attribute produced by a folder.
+///
+/// Folders (e.g. SelectOp::fold) may return attributes instead of values; the
+/// folding driver calls this hook to turn such an attribute back into an
+/// operation of the requested \p type at \p loc.
+///
+/// Returns nullptr when \p value is not a typed attribute and therefore
+/// cannot be expressed as a cir.const. This signals failure to the driver
+/// instead of asserting inside mlir::cast.
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                mlir::Attribute value,
+                                                mlir::Type type,
+                                                mlir::Location loc) {
+  auto typedValue = mlir::dyn_cast<mlir::TypedAttr>(value);
+  if (!typedValue)
+    return nullptr;
+  return builder.create<cir::ConstantOp>(loc, type, typedValue);
+}
+
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Fold a cir.select whose result is statically known.
+///
+/// - A constant boolean condition selects the corresponding operand.
+/// - Two operands that are the same constant fold to that constant.
+/// - Two operands that are the same SSA value fold to that value.
+///
+/// Returns a null OpFoldResult when none of these apply.
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+  // cir.select if #true then x else y -> x
+  // cir.select if #false then x else y -> y
+  // dyn_cast_if_present: the condition attribute is null when it is not a
+  // constant, and a non-bool constant must not trip an assertion.
+  if (auto condition =
+          mlir::dyn_cast_if_present<cir::BoolAttr>(adaptor.getCondition()))
+    return condition.getValue() ? getTrueValue() : getFalseValue();
+
+  // cir.select if %0 then #c else #c -> #c
+  // Both adaptor attributes are null when the operands are not constants;
+  // without the null check that case would "match" here and return an empty
+  // result before reaching the SSA-value comparison below.
+  mlir::Attribute trueValue = adaptor.getTrueValue();
+  mlir::Attribute falseValue = adaptor.getFalseValue();
+  if (trueValue && trueValue == falseValue)
+    return trueValue;
+
+  // cir.select if %0 then %x else %x -> %x
+  if (getTrueValue() == getFalseValue())
+    return getTrueValue();
+
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..b969569b0081c
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,202 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, each of the true
+/// and false branches must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+///   %0 = ...
+///   %1 = cir.ternary (%condition, true {
+///     %2 = cir.const ...
+///     cir.yield %2
+///   } false {
+///     cir.yield %0
+///   })
+///
+/// into the following sequence of operations:
+///
+///   %2 = cir.const ...
+///   %1 = cir.select if %condition then %2 else %0
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+  using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TernaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return mlir::failure();
+
+    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+        !isSimpleTernaryBranch(op.getFalseRegion()))
+      return mlir::failure();
+
+    cir::YieldOp trueBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+    cir::YieldOp falseBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+    // Splice both branch blocks (any cir.const they hold plus their yield)
+    // in front of the ternary, drop the yields, and replace the ternary with
+    // a select over the yielded values.
+    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+    rewriter.eraseOp(trueBranchYieldOp);
+    rewriter.eraseOp(falseBranchYieldOp);
+    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+                                               falseValue);
+
+    return mlir::success();
+  }
+
+private:
+  /// Returns true if \p region consists of a single block that either holds
+  /// only a one-value cir.yield, or holds a cir.const followed by a cir.yield
+  /// of that constant.
+  bool isSimpleTernaryBranch(mlir::Region &region) const {
+    if (!region.hasOneBlock())
+      return false;
+
+    mlir::Block &onlyBlock = region.front();
+    mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+    // The block may contain at most two operations: an optional cir.const and
+    // the terminating cir.yield.
+    if (ops.size() > 2)
+      return false;
+
+    // matchAndRewrite indexes the yield's first operand unconditionally, so
+    // require a cir.yield of exactly one value.
+    auto yieldOp = mlir::dyn_cast<cir::YieldOp>(onlyBlock.getTerminator());
+    if (!yieldOp || yieldOp.getArgs().size() != 1)
+      return false;
+
+    // The block only contains the cir.yield.
+    if (ops.size() == 1)
+      return true;
+
+    // Otherwise the block must contain a cir.const followed by a cir.yield
+    // that yields that constant.
+    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+        yieldOp.getArgs()[0].getDefiningOp());
+    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+  }
+};
+
+/// Rewrite cir.select operations whose operands are both boolean constants.
+///
+/// Only the two mixed combinations are rewritten here; a select between two
+/// equal constants is left alone by this pattern.
+///
+/// 1. Choosing true/false is the condition itself:
+///
+///      %0 = cir.select if %condition then true else false
+///      ->
+///      (all uses of %0 now use %condition)
+///
+/// 2. Choosing false/true is the negated condition:
+///
+///      %0 = cir.select if %condition then false else true
+///      ->
+///      %0 = cir.unary not %condition
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const final {
+    cir::BoolAttr thenConst = getBoolConstant(op.getTrueValue());
+    cir::BoolAttr elseConst = getBoolConstant(op.getFalseValue());
+    if (!thenConst || !elseConst)
+      return mlir::failure();
+
+    bool thenBit = thenConst.getValue();
+    bool elseBit = elseConst.getValue();
+    if (thenBit == elseBit)
+      return mlir::failure();
+
+    if (thenBit) {
+      // cir.select if %0 then #true else #false -> %0
+      rewriter.replaceOp(op, op.getCondition());
+    } else {
+      // cir.select if %0 then #false else #true -> cir.unary not %0
+      rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+                                                op.getCondition());
+    }
+    return mlir::success();
+  }
+
+private:
+  /// Returns the cir::BoolAttr held by the cir.const defining \p value, or a
+  /// null attribute if \p value is not produced by a boolean cir.const.
+  static cir::BoolAttr getBoolConstant(mlir::Value value) {
+    auto constOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp());
+    if (!constOp)
+      return {};
+    return mlir::dyn_cast<cir::BoolAttr>(constOp.getValue());
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+/// Pass driver: greedily applies the ternary/select simplification patterns
+/// to the cir.ternary and cir.select operations nested under the root op.
+struct CIRSimplifyPass : CIRSimplifyBase<CIRSimplifyPass> {
+  using CIRSimplifyBase::CIRSimplifyBase;
+
+  void runOnOperation() override;
+};
+
+/// Add the cir-simplify rewrite patterns (SimplifyTernary, SimplifySelect) to
+/// \p patterns.
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+  mlir::MLIRContext *ctx = patterns.getContext();
+  patterns.add<SimplifyTernary, SimplifySelect>(ctx);
+}
+
+void CIRSimplifyPass::runOnOperation() {
+  // Restrict the greedy driver to the only op kinds our patterns match.
+  llvm::SmallVector<Operation *, 16> candidates;
+  getOperation()->walk([&candidates](Operation *op) {
+    if (isa<TernaryOp, SelectOp>(op))
+      candidates.push_back(op);
+  });
+
+  RewritePatternSet patterns(&getContext());
+  populateMergeCleanupPatterns(patterns);
+
+  if (failed(applyOpPatternsGreedily(candidates, std::move(patterns))))
+    signalPassFailure();
+}
+
+} // namespace
+
+/// Factory for the cir-simplify pass. It is referenced as the pass
+/// `constructor` in Passes.td, registered by cir-opt, and added by
+/// runCIRToCIRPasses when simplification is enabled.
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<CIRSimplifyPass>();
+  return pass;
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
+ CIRSimplify.cpp
FlattenCFG.cpp
HoistAllocas.cpp
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..cc65c93f5f16b 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
+ CodeGenOptions &CGO;
public:
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
- std::unique_ptr<raw_pwrite_stream> OS)
+ CodeGenOptions &CGO, std::unique_ptr<raw_pwrite_stream> OS)
: Action(Action), CI(CI), OutputStream(std::move(OS)),
FS(&CI.getVirtualFileSystem()),
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
CI.getCodeGenOpts())),
- FEOptions(CI.getFrontendOpts()) {}
+ FEOptions(CI.getFrontendOpts()), CGO(CGO) {}
void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
@@ -102,7 +103,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
if (!FEOptions.ClangIRDisablePasses) {
// Setup and run CIR pipeline.
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
- !FEOptions.ClangIRDisableCIRVerifier)
+ !FEOptions.ClangIRDisableCIRVerifier,
+ CGO.OptimizationLevel > 0)
.failed()) {
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
return;
@@ -168,8 +170,8 @@ CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
if (!Out)
Out = getOutputStream(CI, InFile, Action);
- auto Result =
- std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+ auto Result = std::make_unique<cir::CIRGenConsumer>(
+ Action, CI, CI.getCodeGenOpts(), std::move(Out));
return Result;
}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirContext,
clang::ASTContext &astContext,
- bool enableVerifier) {
+ bool enableVerifier,
+ bool enableCIRSimplify) {
llvm::TimeTraceScope scope("CIR To CIR Passes");
mlir::PassManager pm(&mlirContext);
pm.addPass(mlir::createCIRCanonicalizePass());
+ if (enableCIRSimplify)
+ pm.addPass(mlir::createCIRSimplifyPass());
+
pm.enableVerifier(enableVerifier);
(void)mlir::applyPassManagerCLOptions(pm);
return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG1]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.const #cir.int<42> : !s32i
+ %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.return %[[#A]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.const #cir.bool<false> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
+ // CHECK-NEXT: }
+
+ cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.const #cir.bool<true> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
+ // CHECK-NEXT: cir.return %[[#A]] : !cir.bool
+ // CHECK-NEXT: }
+}
diff --git a/clang/test/CIR/Transforms/ternary-fold.cir b/clang/test/CIR/Transforms/ternary-fold.cir
new file mode 100644
index 0000000000000..1192a0ce29424
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary-fold.cir
@@ -0,0 +1,76 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.ternary (%0, true {
+ cir.yield %arg0 : !s32i
+ }, false {
+ cir.yield %arg1 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
+ %0 = cir.ternary (%arg0, true {
+ %1 = cir.const #cir.int<42> : !s32i
+ cir.yield %1 : !s32i
+ }, false {
+ cir.yield %arg1 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
+ // CHECK-NEXT: cir.return %[[#B]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @simplify_ternary_false_const(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
+ %0 = cir.ternary (%arg0, true {
+ cir.yield %arg1 : !s32i
+ }, false {
+ %1 = cir.const #cir.int<24> : !s32i
+ cir.yield %1 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func @simplify_ternary_false_const(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<24> : !s32i
+ // CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[ARG1]] else %[[#A]] : (!cir.bool, !s32i, !s32i) -> !s32i
+ // CHECK-NEXT: cir.return %[[#B]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+ %1 = cir.ternary (%arg0, true {
+ %2 = cir.const #cir.int<42> : !s32i
+ cir.yield %2 : !s32i
+ }, false {
+ %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.yield %3 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+ // CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true {
+ // CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.yield %[[#C]] : !s32i
+ // CHECK-NEXT: }, false {
+ // CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
+ // CHECK-NEXT: cir.yield %[[#D]] : !s32i
+ // CHECK-NEXT: }) : (!cir.bool) -> !s32i
+ // CHECK-NEXT: cir.return %[[#B]] : !s32i
+ // CHECK-NEXT: }
+}
diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp
index e50fa70582966..0e20b97feced8 100644
--- a/clang/tools/cir-opt/cir-opt.cpp
+++ b/clang/tools/cir-opt/cir-opt.cpp
@@ -37,6 +37,9 @@ int main(int argc, char **argv) {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createCIRCanonicalizePass();
});
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
+ return mlir::createCIRSimplifyPass();
+ });
mlir::PassPipelineRegistration<CIRToLLVMPipelineOptions> pipeline(
"cir-to-llvm", "",
More information about the cfe-commits
mailing list