[clang] [CIR] Add cir-simplify pass (PR #138317)
Morris Hafner via cfe-commits
cfe-commits at lists.llvm.org
Fri May 2 11:06:41 PDT 2025
https://github.com/mmha created https://github.com/llvm/llvm-project/pull/138317
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
>From 2b6ecd77c4fac0a2982172294d12ae858f0a2b34 Mon Sep 17 00:00:00 2001
From: Morris Hafner <mhafner at nvidia.com>
Date: Fri, 2 May 2025 20:05:40 +0200
Subject: [PATCH] [CIR] Add cir-simplify pass
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
---
clang/include/clang/CIR/CIRToCIRPasses.h | 3 +-
.../clang/CIR/Dialect/IR/CIRDialect.td | 2 +
clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 +
clang/include/clang/CIR/Dialect/Passes.h | 1 +
clang/include/clang/CIR/Dialect/Passes.td | 14 ++
.../clang/CIR/FrontendAction/CIRGenAction.h | 2 +-
clang/include/clang/CIR/MissingFeatures.h | 1 -
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 30 +++
.../Dialect/Transforms/CIRCanonicalize.cpp | 3 +-
.../CIR/Dialect/Transforms/CIRSimplify.cpp | 184 ++++++++++++++++++
.../lib/CIR/Dialect/Transforms/CMakeLists.txt | 1 +
clang/lib/CIR/FrontendAction/CIRGenAction.cpp | 21 +-
clang/lib/CIR/Lowering/CIRPasses.cpp | 6 +-
clang/test/CIR/Transforms/select.cir | 60 ++++++
clang/test/CIR/Transforms/ternary-fold.cir | 60 ++++++
clang/tools/cir-opt/cir-opt.cpp | 3 +
16 files changed, 378 insertions(+), 15 deletions(-)
create mode 100644 clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
create mode 100644 clang/test/CIR/Transforms/select.cir
create mode 100644 clang/test/CIR/Transforms/ternary-fold.cir
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
- bool enableVerifier);
+ bool enableVerifier,
+ bool enableCIRSimplify);
} // namespace cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..818a605ab74d3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
+ let hasConstantMaterializer = 1;
+
let extraClassDeclaration = [{
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9215543ab67e6..8205718e0fc30 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..46fa97da04ca1 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def CIRSimplify : Pass<"cir-simplify"> {
+  let summary = "Performs CIR simplification and code optimization";
+  let description = [{
+    The pass performs code simplification and optimization on CIR.
+
+    Unlike the `cir-canonicalize` pass, this pass contains more aggressive code
+    transformations that could significantly affect CIR-to-source fidelity.
+    Transformations currently performed include turning simple ternary
+    operations into select operations and simplifying boolean-constant selects.
+  }];
+  let constructor = "mlir::createCIRSimplifyPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
def HoistAllocas : Pass<"cir-hoist-allocas"> {
let summary = "Hoist allocas to the entry of the function";
let description = [{
diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
index 99495f4718c5f..b52166b58b882 100644
--- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
+++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
public:
~CIRGenAction() override;
- OutputType Action;
+ OutputType action;
};
class EmitCIRAction : public CIRGenAction {
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 3db13278261e6..b26144095792d 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -199,7 +199,6 @@ struct MissingFeatures {
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
static bool ptrStrideOp() { return false; }
- static bool selectOp() { return false; }
static bool switchOp() { return false; }
static bool ternaryOp() { return false; }
static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f5d6a424a71f6..5356630ece196 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>();
}
+/// Materializes a single cir.const of type \p type holding \p value, inserted
+/// at the builder's current position. Per the MLIR dialect contract, returns
+/// nullptr (instead of asserting) when \p value is not a TypedAttr and so
+/// cannot be turned into a cir.const.
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                mlir::Attribute value,
+                                                mlir::Type type,
+                                                mlir::Location loc) {
+  auto typedValue = mlir::dyn_cast<mlir::TypedAttr>(value);
+  if (!typedValue)
+    return nullptr;
+  return builder.create<cir::ConstantOp>(loc, type, typedValue);
+}
+
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+/// Folds cir.select when its result does not depend on evaluating the select:
+///   - a constant boolean condition picks the corresponding operand;
+///   - identical operands (the same SSA value, or constants with equal
+///     attributes) make the condition irrelevant.
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+  // cir.select if #true then x else y -> x, and #false -> y. dyn_cast keeps a
+  // non-boolean constant attribute from asserting; such a select is not folded.
+  if (auto condition =
+          mlir::dyn_cast_if_present<cir::BoolAttr>(adaptor.getCondition()))
+    return condition.getValue() ? getTrueValue() : getFalseValue();
+
+  // cir.select if %0 then x else x -> x
+  if (getTrueValue() == getFalseValue())
+    return getTrueValue();
+
+  // Both operands are constants with the same value. The adaptor attributes
+  // are null for non-constant operands, so require a non-null attribute:
+  // otherwise two non-constants would compare equal and yield a null result.
+  mlir::Attribute trueValue = adaptor.getTrueValue();
+  if (trueValue && trueValue == adaptor.getFalseValue())
+    return trueValue;
+
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..442801d062638
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,184 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, each of the true and
+/// false branches must either contain a cir.yield operation as its only
+/// operation, or contain a cir.const operation followed by a cir.yield
+/// operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+///   %0 = cir.ternary (%condition, true {
+///     %1 = cir.const ...
+///     cir.yield %1
+///   } false {
+///     cir.yield %2
+///   })
+///
+/// into the following sequence of operations:
+///
+///   %1 = cir.const ...
+///   %0 = cir.select if %condition then %1 else %2
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+  using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TernaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return mlir::failure();
+
+    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+        !isSimpleTernaryBranch(op.getFalseRegion()))
+      return mlir::failure();
+
+    // isSimpleTernaryBranch guarantees both terminators are single-operand
+    // cir.yield ops, so these casts and [0] accesses cannot fail.
+    auto trueBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+    auto falseBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+    // Move each branch's operations (an optional cir.const plus the yield)
+    // in front of the ternary, then drop the now-meaningless yields.
+    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+    rewriter.eraseOp(trueBranchYieldOp);
+    rewriter.eraseOp(falseBranchYieldOp);
+    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+                                               falseValue);
+
+    return mlir::success();
+  }
+
+private:
+  /// Returns true if \p region is a single block holding either only a
+  /// single-operand cir.yield, or a cir.const immediately followed by a
+  /// cir.yield of that constant.
+  bool isSimpleTernaryBranch(mlir::Region &region) const {
+    if (!region.hasOneBlock())
+      return false;
+
+    mlir::Block &onlyBlock = region.front();
+    mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+    // The region/block may contain at most 2 operations.
+    if (ops.empty() || ops.size() > 2)
+      return false;
+
+    // The last operation must be a cir.yield of exactly one value; checking it
+    // here keeps the casts in matchAndRewrite from asserting.
+    auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&ops.back());
+    if (!yieldOp || yieldOp.getArgs().size() != 1)
+      return false;
+
+    // The region/block only contains a cir.yield operation.
+    if (ops.size() == 1)
+      return true;
+
+    // Check whether the region/block contains a cir.const followed by a
+    // cir.yield that yields the constant's value.
+    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+        yieldOp.getArgs()[0].getDefiningOp());
+    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+  }
+};
+
+/// Simplify a cir.select whose operands are two different boolean constants:
+///   cir.select if %0 then #true else #false -> %0
+///   cir.select if %0 then #false else #true -> cir.unary(not, %0)
+struct SimplifySelect final : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const final {
+    auto trueValueConstOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
+    auto falseValueConstOp =
+        op.getFalseValue().getDefiningOp<cir::ConstantOp>();
+    if (!trueValueConstOp || !falseValueConstOp)
+      return mlir::failure();
+
+    auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+    auto falseValue =
+        mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+    if (!trueValue || !falseValue)
+      return mlir::failure();
+
+    // Equal constants are not handled by this pattern (the SelectOp folder
+    // covers selects with identical operands).
+    if (trueValue.getValue() == falseValue.getValue())
+      return mlir::failure();
+
+    // cir.select if %0 then #true else #false -> %0
+    if (trueValue.getValue()) {
+      rewriter.replaceOp(op, op.getCondition());
+      return mlir::success();
+    }
+
+    // cir.select if %0 then #false else #true -> cir.unary not %0
+    rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+                                              op.getCondition());
+    return mlir::success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+/// The cir-simplify pass. Its constructors are inherited from
+/// CIRSimplifyBase; the pattern application lives in runOnOperation.
+struct CIRSimplifyPass : CIRSimplifyBase<CIRSimplifyPass> {
+  using CIRSimplifyBase::CIRSimplifyBase;
+  void runOnOperation() override;
+};
+
+/// Registers the rewrite patterns applied by the cir-simplify pass.
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<SimplifyTernary, SimplifySelect>(context);
+}
+
+void CIRSimplifyPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  populateMergeCleanupPatterns(patterns);
+
+  // Only cir.ternary and cir.select are targeted by this pass, so seed the
+  // greedy driver's worklist with just those operations.
+  llvm::SmallVector<Operation *, 16> candidates;
+  getOperation()->walk([&candidates](Operation *op) {
+    if (isa<TernaryOp, SelectOp>(op))
+      candidates.push_back(op);
+  });
+
+  LogicalResult status =
+      applyOpPatternsGreedily(candidates, std::move(patterns));
+  if (status.failed())
+    signalPassFailure();
+}
+
+} // namespace
+
+/// Creates a new instance of the cir-simplify pass.
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<CIRSimplifyPass>();
+  return pass;
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
+ CIRSimplify.cpp
FlattenCFG.cpp
HoistAllocas.cpp
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..570403dda9d9f 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
+ CodeGenOptions &codeGenOptions;
public:
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
+ CodeGenOptions &codeGenOptions,
std::unique_ptr<raw_pwrite_stream> OS)
: Action(Action), CI(CI), OutputStream(std::move(OS)),
FS(&CI.getVirtualFileSystem()),
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
CI.getCodeGenOpts())),
- FEOptions(CI.getFrontendOpts()) {}
+ FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {}
void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
@@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
if (!FEOptions.ClangIRDisablePasses) {
// Setup and run CIR pipeline.
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
- !FEOptions.ClangIRDisableCIRVerifier)
+ !FEOptions.ClangIRDisableCIRVerifier,
+ codeGenOptions.OptimizationLevel > 0)
.failed()) {
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
return;
@@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
void CIRGenConsumer::anchor() {}
+/// Uses \p MLIRCtx when it is non-null; otherwise a new mlir::MLIRContext is
+/// allocated for this action.
CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)
- : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {}
+ : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), action(Act) {}
CIRGenAction::~CIRGenAction() { MLIRMod.release(); }
@@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
}
std::unique_ptr<ASTConsumer>
-CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
- std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();
+CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) {
+ std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream();
- if (!Out)
- Out = getOutputStream(CI, InFile, Action);
+ if (!out)
+ out = getOutputStream(ci, inFile, action);
- auto Result =
- std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+ auto Result = std::make_unique<cir::CIRGenConsumer>(
+ action, ci, ci.getCodeGenOpts(), std::move(out));
return Result;
}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirContext,
clang::ASTContext &astContext,
- bool enableVerifier) {
+ bool enableVerifier,
+ bool enableCIRSimplify) {
llvm::TimeTraceScope scope("CIR To CIR Passes");
mlir::PassManager pm(&mlirContext);
pm.addPass(mlir::createCIRCanonicalizePass());
+ if (enableCIRSimplify)
+ pm.addPass(mlir::createCIRSimplifyPass());
+
pm.enableVerifier(enableVerifier);
(void)mlir::applyPassManagerCLOptions(pm);
return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ // A select with a constant #true condition folds to its true operand.
+ cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !s32i
+ // CHECK-NEXT: }
+
+ // A select with a constant #false condition folds to its false operand.
+ cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG1]] : !s32i
+ // CHECK-NEXT: }
+
+ // A select whose two operands are the same constant folds to that constant,
+ // whatever the (non-constant) condition is.
+ cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.const #cir.int<42> : !s32i
+ %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.return %[[#A]] : !s32i
+ // CHECK-NEXT: }
+
+ // select(%c, #true, #false) is simplified to the condition %c itself.
+ cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.const #cir.bool<false> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
+ // CHECK-NEXT: }
+
+ // select(%c, #false, #true) is simplified to a cir.unary not of %c.
+ cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.const #cir.bool<true> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
+ // CHECK-NEXT: cir.return %[[#A]] : !cir.bool
+ // CHECK-NEXT: }
+}
diff --git a/clang/test/CIR/Transforms/ternary-fold.cir b/clang/test/CIR/Transforms/ternary-fold.cir
new file mode 100644
index 0000000000000..72ba4815b2db2
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary-fold.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ // A ternary whose branches only yield values becomes a select; the select's
+ // constant #false condition then folds it down to the false operand.
+ cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.ternary (%0, true {
+ cir.yield %arg0 : !s32i
+ }, false {
+ cir.yield %arg1 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG]] : !s32i
+ // CHECK-NEXT: }
+
+ // A branch made of a cir.const followed by a cir.yield of it is still simple:
+ // the constant is moved in front of the resulting cir.select.
+ cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
+ %0 = cir.ternary (%arg0, true {
+ %1 = cir.const #cir.int<42> : !s32i
+ cir.yield %1 : !s32i
+ }, false {
+ cir.yield %arg1 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
+ // CHECK-NEXT: cir.return %[[#B]] : !s32i
+ // CHECK-NEXT: }
+
+ // A branch containing any other operation (here a cir.load) is not simple,
+ // so the ternary must be left unchanged.
+ cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+ %1 = cir.ternary (%arg0, true {
+ %2 = cir.const #cir.int<42> : !s32i
+ cir.yield %2 : !s32i
+ }, false {
+ %3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+ cir.yield %3 : !s32i
+ }) : (!cir.bool) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
+ // CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true {
+ // CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.yield %[[#C]] : !s32i
+ // CHECK-NEXT: }, false {
+ // CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
+ // CHECK-NEXT: cir.yield %[[#D]] : !s32i
+ // CHECK-NEXT: }) : (!cir.bool) -> !s32i
+ // CHECK-NEXT: cir.return %[[#B]] : !s32i
+ // CHECK-NEXT: }
+}
diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp
index e50fa70582966..0e20b97feced8 100644
--- a/clang/tools/cir-opt/cir-opt.cpp
+++ b/clang/tools/cir-opt/cir-opt.cpp
@@ -37,6 +37,9 @@ int main(int argc, char **argv) {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::createCIRCanonicalizePass();
});
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
+ return mlir::createCIRSimplifyPass();
+ });
mlir::PassPipelineRegistration<CIRToLLVMPipelineOptions> pipeline(
"cir-to-llvm", "",
More information about the cfe-commits
mailing list