[Mlir-commits] [mlir] [mlir][emitc] Add op modelling C expressions (PR #71631)
Gil Rapaport
llvmlistbot at llvm.org
Fri Nov 10 01:55:13 PST 2023
https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/71631
>From fd7b19c9c70db48b0eecb487eb458e8213b708f2 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Wed, 18 Oct 2023 17:48:29 +0300
Subject: [PATCH 01/21] [mlir][emitc] Add op modelling C expressions
Add an emitc.expression operation that models C expressions, and provide
transforms to form and fold expressions. The translator emits the body of
emitc.expression ops as a single C expression.
This expression is emitted by default as the RHS of an EmitC SSA value, but if
possible, expressions with a single use that is not another expression are
instead inlined. The inlining of specific expressions can be fine-tuned by
lowering passes and transforms.
---
.../include/mlir/Dialect/EmitC/CMakeLists.txt | 2 +
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 98 ++++++-
.../Dialect/EmitC/TransformOps/CMakeLists.txt | 6 +
.../EmitC/TransformOps/EmitCTransformOps.h | 49 ++++
.../EmitC/TransformOps/EmitCTransformOps.td | 70 +++++
.../Dialect/EmitC/Transforms/CMakeLists.txt | 5 +
.../mlir/Dialect/EmitC/Transforms/Passes.h | 35 +++
.../mlir/Dialect/EmitC/Transforms/Passes.td | 24 ++
.../Dialect/EmitC/Transforms/Transforms.h | 34 +++
mlir/include/mlir/InitAllExtensions.h | 2 +
mlir/include/mlir/InitAllPasses.h | 2 +
mlir/lib/Dialect/EmitC/CMakeLists.txt | 2 +
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 93 +++++++
.../Dialect/EmitC/TransformOps/CMakeLists.txt | 15 ++
.../EmitC/TransformOps/EmitCTransformOps.cpp | 114 ++++++++
.../Dialect/EmitC/Transforms/CMakeLists.txt | 16 ++
.../EmitC/Transforms/FormExpressions.cpp | 60 +++++
.../Dialect/EmitC/Transforms/Transforms.cpp | 117 ++++++++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 250 ++++++++++++++++--
mlir/test/Dialect/EmitC/invalid_ops.mlir | 59 ++++-
mlir/test/Dialect/EmitC/ops.mlir | 17 ++
mlir/test/Dialect/EmitC/transform-ops.mlir | 131 +++++++++
mlir/test/Target/Cpp/expressions.mlir | 212 +++++++++++++++
mlir/test/Target/Cpp/for.mlir | 22 +-
24 files changed, 1399 insertions(+), 36 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/EmitC/TransformOps/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
create mode 100644 mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
create mode 100644 mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
create mode 100644 mlir/lib/Dialect/EmitC/TransformOps/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
create mode 100644 mlir/test/Dialect/EmitC/transform-ops.mlir
create mode 100644 mlir/test/Target/Cpp/expressions.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt
index f33061b2d87cffc..cb1e9d01821a2cf 100644
--- a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt
@@ -1 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 2edeb6f8a9cf01e..4d522565c32826d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
//===----------------------------------------------------------------------===//
// EmitC op definitions
@@ -246,6 +247,85 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}
+def EmitC_ExpressionOp : EmitC_Op<"expression",
+      [HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">,
+       NoRegionArguments]> {
+  let summary = "Expression operation";
+  let description = [{
+    The `expression` operation returns a single SSA value which is yielded by
+    its single-basic-block region. The operation doesn't take any operands.
+
+    As the operation is to be emitted as a C expression, the operations within
+    its body must form a single Def-Use tree of emitc ops whose result is
+    yielded by a terminating `emitc.yield`.
+
+    Example:
+
+    ```mlir
+    %r = emitc.expression : i32 {
+      %0 = emitc.add %a, %b : (i32, i32) -> i32
+      %1 = emitc.call "foo"(%0) : (i32) -> i32
+      %2 = emitc.add %c, %d : (i32, i32) -> i32
+      %3 = emitc.mul %1, %2 : (i32, i32) -> i32
+      emitc.yield %3 : i32
+    }
+    ```
+
+    May be emitted as
+
+    ```c++
+    int32_t v7 = foo(v1 + v2) * (v3 + v4);
+    ```
+
+    The operations allowed within expression body are emitc.add, emitc.apply,
+    emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and
+    emitc.sub. Each of them must have exactly one result with exactly one use.
+
+    When specified, the optional `do_not_inline` unit attribute (spelled in the
+    attribute dictionary, e.g. `emitc.expression {do_not_inline} : i32 {...}`)
+    indicates that the expression is to be emitted as seen above, i.e. as the
+    rhs of an EmitC SSA value definition. Otherwise, the expression may be
+    emitted inline, i.e. directly at its use.
+  }];
+
+  let arguments = (ins UnitAttr:$do_not_inline);
+  let results = (outs AnyType);
+  let regions = (region SizedRegion<1>:$region);
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true if `op` models a C operator and may therefore appear in
+    /// the body of an expression.
+    static bool isCExpression(Operation &op) {
+      return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
+                 emitc::CmpOp, emitc::DivOp, emitc::MulOp, emitc::RemOp,
+                 emitc::SubOp>(op);
+    }
+    /// Conservatively returns true if evaluating the C operator op `op` may
+    /// read or write memory.
+    static bool hasSideEffects(Operation &op) {
+      assert(isCExpression(op) && "Expected a C operator");
+      // Conservatively assume calls to read and write memory.
+      if (isa<emitc::CallOp>(op))
+        return true;
+      // De-referencing reads modifiable memory.
+      auto applyOp = dyn_cast<emitc::ApplyOp>(op);
+      if (applyOp && applyOp.getApplicableOperator() == "*")
+        return true;
+      // Any operator using variables has a side effect of reading memory
+      // mutable by emitc::assign ops.
+      for (Value operand : op.getOperands()) {
+        Operation *def = operand.getDefiningOp();
+        if (def && isa<emitc::VariableOp>(def))
+          return true;
+      }
+      return false;
+    }
+    /// Returns true if any op in the expression's body has side effects.
+    bool hasSideEffects() {
+      return llvm::any_of(getRegion().front().without_terminator(),
+                          [](Operation &op) { return hasSideEffects(op); });
+    }
+    /// Returns the op defining the value yielded by the expression's body.
+    Operation *getRootOp();
+  }];
+}
+
def EmitC_ForOp : EmitC_Op<"for",
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
@@ -492,18 +572,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}
def EmitC_YieldOp : EmitC_Op<"yield",
- [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
+ [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
- "yield" terminates blocks within EmitC control-flow operations. Since
- control-flow constructs in C do not return values, this operation doesn't
- take any arguments.
+ "yield" terminates its parent EmitC op's region, optionally yielding
+ an SSA value. The semantics of how the values are yielded is defined by the
+ parent operation.
+ If "yield" has an operand, the operand must match the parent operation's
+ result. If the parent operation defines no values, then the "emitc.yield"
+ may be left out in the custom syntax and the builders will insert one
+ implicitly. Otherwise, it has to be present in the syntax to indicate which
+ value is yielded.
}];
- let arguments = (ins);
+ let arguments = (ins Optional<AnyType>:$result);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
- let assemblyFormat = [{ attr-dict }];
+ // The C++ verifier requires the parent op to have exactly one result when a
+ // value is yielded, and no results when nothing is yielded.
+ let hasVerifier = 1;
+ let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }];
}
def EmitC_IfOp : EmitC_Op<"if",
diff --git a/mlir/include/mlir/Dialect/EmitC/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/TransformOps/CMakeLists.txt
new file mode 100644
index 000000000000000..364398d2dc6b4eb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS EmitCTransformOps.td)
+mlir_tablegen(EmitCTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(EmitCTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIREmitCTransformOpsIncGen)
+
+# Op documentation for the EmitC transform dialect extension.
+add_mlir_doc(EmitCTransformOps EmitCTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
new file mode 100644
index 000000000000000..5b31080c70d9fcb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
@@ -0,0 +1,49 @@
+//===- EmitCTransformOps.h - EmitC transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares the EmitC transform dialect extension ops and their registration.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_TRANSFORMOPS_EMITCTRANSFORMOPS_H
+#define MLIR_DIALECT_EMITC_TRANSFORMOPS_EMITCTRANSFORMOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace emitc {
+class ExpressionOp;
+
+/// Registers the EmitC transform dialect extension (the EmitC transform ops)
+/// in `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace emitc
+} // namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_EMITC_TRANSFORMOPS_EMITCTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
new file mode 100644
index 000000000000000..1e442af2e4c1edb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
@@ -0,0 +1,70 @@
+//===- EmitCTransformOps.td - EmitC transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef EMITC_TRANSFORM_OPS
+#define EMITC_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def ApplyExpressionPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.emitc.expressions",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let summary = "Apply EmitC expression-related patterns";
+  let description = [{
+    Collects the EmitC expression-related rewrite patterns (as populated by
+    `emitc::populateExpressionPatterns`) for use within a
+    `transform.apply_patterns` region.
+  }];
+
+  let arguments = (ins);
+  let assemblyFormat = [{ attr-dict }];
+}
+
+def CreateExpressionOp : Op<Transform_Dialect, "expression.create",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Wrap C operators in emitc.expressions";
+  let description = [{
+    For each payload operation, constructs an `emitc.expression` wrapping that
+    operation and yielding the value it defines.
+
+    #### Return Modes
+
+    Produces a silenceable failure if the operand is not associated with emitc C
+    operator payload operations. Returns a single handle associated with the
+    generated `emitc.expression` ops.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
+def MatchExpressionRootOp : Op<Transform_Dialect, "expression.match.root",
+    [MemoryEffectsOpInterface,
+     NavigationTransformOpTrait,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Match emitc ops modelling C operators";
+  let description = [{
+    Walks the single payload operation associated with `target`, including
+    that operation itself and everything nested in it, and returns a handle to
+    every emitc op found that models a C operator (emitc.add, emitc.apply,
+    emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and
+    emitc.sub), in post-order. The single-result ones among them can then be
+    wrapped using `transform.expression.create`.
+
+    #### Return Modes
+
+    Produces a definite failure if `target` is not associated with exactly one
+    payload operation. The payload IR is not modified.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$results);
+
+  let assemblyFormat = [{
+    `in` $target attr-dict
+    `:` functional-type($target, results)
+  }];
+}
+
+#endif // EMITC_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..0b507d75fa07a6b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+# Generates the EmitC pass declarations/registration (Passes.h.inc) from
+# Passes.td, plus the pass documentation (EmitCPasses).
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC)
+add_public_tablegen_target(MLIREmitCTransformsIncGen)
+
+add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
new file mode 100644
index 000000000000000..5cd27149d366ea0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry points for the EmitC dialect passes (generated from Passes.td).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
+#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace emitc {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+/// Creates an instance of the C-style expressions forming pass.
+/// The pass is registered under the name `form-expressions`.
+std::unique_ptr<Pass> createFormExpressionsPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+/// This provides, among others, registerEmitCPasses().
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+} // namespace emitc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
new file mode 100644
index 000000000000000..fd083abc9571578
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -0,0 +1,24 @@
+//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
+#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+// The pass is not anchored on a specific op type, so it may be scheduled on
+// any operation; it creates emitc ops and therefore depends on EmitC.
+def FormExpressions : Pass<"form-expressions"> {
+ let summary = "Form C-style expressions from C-operator ops";
+ let description = [{
+ The pass wraps emitc ops modelling C operators in emitc.expression ops and
+ then folds single-use expressions into their users where possible.
+ }];
+ let constructor = "mlir::emitc::createFormExpressionsPass()";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
+#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
new file mode 100644
index 000000000000000..73981df131681c8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -0,0 +1,34 @@
+//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace emitc {
+
+//===----------------------------------------------------------------------===//
+// Expression transforms.
+//===----------------------------------------------------------------------===//
+
+/// Wraps `op`, which must be a C operator op (see
+/// ExpressionOp::isCExpression) with exactly one result, in a new
+/// emitc.expression created right after it. All uses of op's result are
+/// redirected to the expression's result, and op is moved into the
+/// expression's body, whose terminator yields op's result. Returns the new
+/// expression op.
+ExpressionOp createExpression(Operation *op, OpBuilder &builder);
+
+//===----------------------------------------------------------------------===//
+// Populate functions.
+//===----------------------------------------------------------------------===//
+
+/// Populates `patterns` with expression-related patterns.
+void populateExpressionPatterns(RewritePatternSet &patterns);
+
+} // namespace emitc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index c04ce850fb96f41..0f494f296cd6624 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -25,6 +25,7 @@
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
@@ -67,6 +68,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
bufferization::registerTransformDialectExtension(registry);
+ emitc::registerTransformDialectExtension(registry);
func::registerTransformDialectExtension(registry);
gpu::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index f22980036ffcfa1..5207559f3625095 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
@@ -86,6 +87,7 @@ inline void registerAllPasses() {
vector::registerVectorPasses();
arm_sme::registerArmSMEPasses();
arm_sve::registerArmSVEPasses();
+ emitc::registerEmitCPasses();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/EmitC/CMakeLists.txt b/mlir/lib/Dialect/EmitC/CMakeLists.txt
index f33061b2d87cffc..660deb21479d297 100644
--- a/mlir/lib/Dialect/EmitC/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/CMakeLists.txt
@@ -1 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(TransformOps)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index d06381b7ddad3dc..26599f62680319f 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -189,6 +189,82 @@ LogicalResult emitc::ConstantOp::verify() {
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// ExpressionOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the op defining the value yielded by this expression's body, i.e.
+/// the root of the expression's Def-Use tree.
+Operation *ExpressionOp::getRootOp() {
+  Value yielded =
+      cast<YieldOp>(getRegion().front().getTerminator()).getResult();
+  Operation *root = yielded.getDefiningOp();
+  assert(root && "Yielded value not defined within expression");
+  return root;
+}
+
+/// Parses `emitc.expression` attr-dict `:` result-type region.
+ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Optional attribute dictionary, then `:` and the single result type.
+  Type expressionType;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(expressionType))
+    return failure();
+  result.addTypes(expressionType);
+
+  // The body region takes no block arguments.
+  result.regions.reserve(1);
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  return success();
+}
+
+/// Prints the op as attr-dict `:` result-type followed by the body region,
+/// including its terminator.
+void ExpressionOp::print(OpAsmPrinter &p) {
+  p.printOptionalAttrDict((*this)->getAttrs());
+  Type resultType = getResult().getType();
+  p << " : " << resultType << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+}
+
+/// Verifies that the body yields a value of the result type, that this value
+/// is produced by an op of the body (the expression's root), and that the
+/// body only holds single-result, single-use C operator ops.
+LogicalResult ExpressionOp::verify() {
+  Type resultType = getResult().getType();
+  Region &region = getRegion();
+
+  Block &body = region.front();
+
+  if (!body.mightHaveTerminator())
+    return emitOpError("must yield a value at termination");
+
+  // Region-trait checks (terminator kind) may run after this verifier, so do
+  // not assume the terminator is a YieldOp.
+  auto yield = dyn_cast<YieldOp>(body.getTerminator());
+  if (!yield)
+    return emitOpError("must yield a value at termination");
+
+  Value yieldResult = yield.getResult();
+
+  if (!yieldResult)
+    return emitOpError("must yield a value at termination");
+
+  Type yieldType = yieldResult.getType();
+
+  if (resultType != yieldType)
+    return emitOpError("requires yielded type to match return type");
+
+  // The yielded value must be the result of an op within this expression;
+  // getRootOp() asserts on this.
+  Operation *rootOp = yieldResult.getDefiningOp();
+  if (!rootOp)
+    return emitOpError("yielded value has no defining op");
+  if (rootOp->getParentOp() != getOperation())
+    return emitOpError("yielded value not defined within expression");
+
+  for (Operation &op : body.without_terminator()) {
+    if (!isCExpression(op))
+      return emitOpError("contains an unsupported operation");
+    if (op.getNumResults() != 1)
+      return emitOpError("requires exactly one result for each operation");
+    if (!op.getResult(0).hasOneUse())
+      return emitOpError("requires exactly one use for each operation");
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@@ -530,6 +606,23 @@ LogicalResult emitc::VariableOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+/// A yield carries a value exactly when its parent op produces one: yielding
+/// requires a single-result parent, not yielding requires a result-less one.
+LogicalResult emitc::YieldOp::verify() {
+  unsigned parentResults = getOperation()->getParentOp()->getNumResults();
+
+  if (getResult()) {
+    if (parentResults != 1)
+      return emitOpError() << "yields a value not returned by parent";
+  } else if (parentResults != 0) {
+    return emitOpError() << "does not yield a value to be returned by parent";
+  }
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/EmitC/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/EmitC/TransformOps/CMakeLists.txt
new file mode 100644
index 000000000000000..795735fb5ee5400
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/TransformOps/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Library providing the EmitC transform dialect extension ops; it builds on the
+# EmitC expression transforms in MLIREmitCTransforms.
+add_mlir_dialect_library(MLIREmitCTransformOps
+ EmitCTransformOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/TransformOps
+
+ DEPENDS
+ MLIREmitCTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIREmitCDialect
+ MLIREmitCTransforms
+ MLIRTransformDialect
+)
diff --git a/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp b/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
new file mode 100644
index 000000000000000..ebad021556f7433
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
@@ -0,0 +1,114 @@
+//===- EmitCTransformOps.cpp - Implementation of EmitC transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "emitc-transforms"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ApplyExpressionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Contributes the EmitC expression-related patterns to the enclosing
+/// pattern application.
+void transform::ApplyExpressionPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::emitc::populateExpressionPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// CreateExpressionOp
+//===----------------------------------------------------------------------===//
+
+/// Wraps every payload op in its own emitc.expression and maps the result
+/// handle to the created expressions. Fails silenceably, without touching the
+/// payload IR, if any payload op is not a single-result C operator op.
+DiagnosedSilenceableFailure
+transform::CreateExpressionOp::apply(transform::TransformRewriter &rewriter,
+                                     transform::TransformResults &results,
+                                     transform::TransformState &state) {
+  SmallVector<Operation *> payloadOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  // Check every payload op before rewriting anything, so that a failure never
+  // leaves the payload IR partially transformed.
+  for (Operation *op : payloadOps) {
+    if (!emitc::ExpressionOp::isCExpression(*op)) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "requires payload to be a C expression";
+      diag.attachNote(op->getLoc()) << "payload op";
+      return diag;
+    }
+    // An emitc.expression yields exactly one value, so e.g. a result-less
+    // emitc.call cannot be wrapped (createExpression asserts on it).
+    if (op->getNumResults() != 1) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "requires payload with exactly one result";
+      diag.attachNote(op->getLoc()) << "payload op";
+      return diag;
+    }
+  }
+
+  SmallVector<Operation *> expressions;
+  expressions.reserve(payloadOps.size());
+  for (Operation *op : payloadOps)
+    expressions.push_back(emitc::createExpression(op, rewriter));
+
+  // Set results.
+  results.set(cast<OpResult>(getResult()), expressions);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is read but not consumed, a fresh handle to the created
+/// expressions is produced, and the payload IR is rewritten.
+void transform::CreateExpressionOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Operation *self = getOperation();
+  onlyReadsHandle(self->getOperands(), effects);
+  producesHandle(self->getResults(), effects);
+  modifiesPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// MatchExpressionRootOp
+//===----------------------------------------------------------------------===//
+
+/// Collects, in walk (post-)order, every C operator op nested under (and
+/// including) the single payload op and maps the result handle to them.
+DiagnosedSilenceableFailure
+transform::MatchExpressionRootOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  auto payloadOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payloadOps))
+    return emitDefiniteFailure("requires exactly one target handle");
+
+  Operation *scope = *payloadOps.begin();
+  SmallVector<Operation *> matched;
+  scope->walk([&matched](Operation *candidate) {
+    if (emitc::ExpressionOp::isCExpression(*candidate))
+      matched.push_back(candidate);
+  });
+
+  results.set(cast<OpResult>(getResult()), matched);
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension registering the EmitC transform ops and
+/// declaring that they may create EmitC dialect ops.
+class EmitCTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          EmitCTransformDialectExtension> {
+public:
+  EmitCTransformDialectExtension() {
+    declareGeneratedDialect<emitc::EmitCDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp.inc"
+
+void mlir::emitc::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<EmitCTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..bfcc14523f137ae
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Library with the EmitC expression transforms (createExpression, expression
+# patterns) and the form-expressions pass.
+add_mlir_dialect_library(MLIREmitCTransforms
+ Transforms.cpp
+ FormExpressions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
+
+ DEPENDS
+ MLIREmitCTransformsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPass
+ MLIREmitCDialect
+ MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
new file mode 100644
index 000000000000000..ed35c914f06066a
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
@@ -0,0 +1,60 @@
+//===- FormExpressions.cpp - Form C-style expressions ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that forms EmitC operations modeling C operators
+// into C-style expressions using the emitc.expression op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+
+namespace mlir {
+namespace emitc {
+#define GEN_PASS_DEF_FORMEXPRESSIONS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+} // namespace emitc
+} // namespace mlir
+
+using namespace mlir;
+using namespace emitc;
+
+namespace {
+/// Wraps every eligible C operator op in its own emitc.expression, then
+/// greedily applies the expression patterns to fold expressions together.
+struct FormExpressionsPass
+    : public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+    MLIRContext *context = rootOp->getContext();
+
+    // Wrap each C operator op with an expression op. Only ops producing
+    // exactly one value can be wrapped, since an expression yields a single
+    // value (e.g. a result-less emitc.call must be left alone). Ops already
+    // inside an expression are skipped: wrapping them would nest an
+    // expression inside an expression body, which only admits C operators.
+    OpBuilder builder(context);
+    auto matchFun = [&](Operation *op) {
+      if (emitc::ExpressionOp::isCExpression(*op) &&
+          op->getNumResults() == 1 &&
+          !op->getParentOfType<emitc::ExpressionOp>())
+        createExpression(op, builder);
+    };
+    rootOp->walk(matchFun);
+
+    // Fold expressions where possible.
+    RewritePatternSet patterns(context);
+    populateExpressionPatterns(patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
+      return signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<emitc::EmitCDialect>();
+  }
+};
+} // namespace
+
+/// Factory for the `form-expressions` pass.
+std::unique_ptr<Pass> mlir::emitc::createFormExpressionsPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<FormExpressionsPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
new file mode 100644
index 000000000000000..347e8a2b305da54
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -0,0 +1,117 @@
+//===- Transforms.cpp - Patterns and transforms for the EmitC dialect------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace emitc {
+
+/// Wrap the single-result C-expression op \p op into a fresh emitc.expression
+/// placed right after it. The expression's body holds \p op followed by a
+/// yield of its result, and every former user of that result now uses the
+/// expression's result instead. Leaves \p builder's insertion point at the
+/// end of the new body.
+ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
+  assert(ExpressionOp::isCExpression(*op) && "Expected a C expression");
+  assert(op->getNumResults() == 1 && "Expected exactly one result");
+
+  Value wrapped = op->getResult(0);
+  Location opLoc = op->getLoc();
+
+  // Materialize the (still empty) expression right after the wrapped op.
+  builder.setInsertionPointAfter(op);
+  auto wrapper = builder.create<emitc::ExpressionOp>(opLoc, wrapped.getType());
+
+  // Redirect existing users first: the yield created below is a new use of
+  // `wrapped` that must keep referring to the original value.
+  wrapped.replaceAllUsesWith(wrapper.getResult());
+
+  // Body: a single block terminated by a yield of the wrapped value.
+  Block &body = wrapper.getRegion().emplaceBlock();
+  builder.setInsertionPointToEnd(&body);
+  Operation *terminator = builder.create<emitc::YieldOp>(opLoc, wrapped);
+
+  // Finally move the wrapped op into the body, ahead of the yield.
+  op->moveBefore(terminator);
+
+  return wrapper;
+}
+
+} // namespace emitc
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::emitc;
+
+namespace {
+
+/// Folds single-use, side-effect-free expressions into the expression using
+/// them: the used expression's body is cloned right before the using op, and
+/// the used expression is replaced by the cloned root's result. The operand of
+/// an `emitc.apply "&"` is never folded, as its address is being taken.
+struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
+  using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+                                PatternRewriter &rewriter) const override {
+    bool anythingFolded = false;
+    for (Operation &op : llvm::make_early_inc_range(
+             expressionOp.getBody()->without_terminator())) {
+      // Don't fold expressions whose result value has its address taken.
+      auto applyOp = dyn_cast<emitc::ApplyOp>(op);
+      if (applyOp && applyOp.getApplicableOperator() == "&")
+        continue;
+
+      for (Value operand : op.getOperands()) {
+        auto usedExpression =
+            dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
+
+        if (!usedExpression)
+          continue;
+
+        // Don't fold expressions with multiple users: assume any
+        // re-materialization was done separately.
+        if (!usedExpression.getResult().hasOneUse())
+          continue;
+
+        // Don't fold expressions with side effects.
+        if (usedExpression.hasSideEffects())
+          continue;
+
+        // Fold the used expression into this expression by cloning all
+        // instructions in the used expression just before the operation using
+        // its value.
+        rewriter.setInsertionPoint(&op);
+        IRMapping mapper;
+        for (Operation &opToClone :
+             usedExpression.getBody()->without_terminator()) {
+          Operation *clone = rewriter.clone(opToClone, mapper);
+          mapper.map(&opToClone, clone);
+        }
+
+        auto usedYield =
+            cast<YieldOp>(usedExpression.getBody()->getTerminator());
+        Operation *expressionRoot = usedYield.getResult().getDefiningOp();
+        assert(expressionRoot && "Used expression has no root operation");
+        Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
+        assert(clonedExpressionRootOp &&
+               "Expected cloned expression root to be in mapper");
+        assert(clonedExpressionRootOp->getNumResults() == 1 &&
+               "Expected cloned root to have a single result");
+
+        Value clonedExpressionResult = clonedExpressionRootOp->getResult(0);
+
+        // Go through the rewriter (not Value::replaceAllUsesWith) so the
+        // greedy driver is notified of the modified user and of the erased
+        // expression.
+        rewriter.replaceOp(usedExpression, clonedExpressionResult);
+        anythingFolded = true;
+      }
+    }
+    return anythingFolded ? success() : failure();
+  }
+};
+
+} // namespace
+
+/// Register the expression-folding pattern(s) into \p patterns.
+void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FoldExpressionOp>(ctx);
+}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 6c95eb3d20dacde..9765c4a0387fb98 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/IndentedOstream.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
@@ -65,6 +66,38 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
+/// Return the precedence of an operator as an integer; higher values imply
+/// higher precedence. The relative order mirrors the C++ operator precedence
+/// table: calls bind tightest, then the unary apply/cast operators, then
+/// multiplicative, additive, three-way comparison, relational and finally
+/// equality operators.
+static int getOperatorPrecedence(Operation *operation) {
+  return llvm::TypeSwitch<Operation *, int>(operation)
+      .Case<emitc::AddOp>([&](auto op) { return 11; })
+      .Case<emitc::ApplyOp>([&](auto op) { return 13; })
+      .Case<emitc::CastOp>([&](auto op) { return 13; })
+      .Case<emitc::CmpOp>([&](auto op) {
+        switch (op.getPredicate()) {
+        case emitc::CmpPredicate::eq:
+        case emitc::CmpPredicate::ne:
+          return 8;
+        case emitc::CmpPredicate::lt:
+        case emitc::CmpPredicate::le:
+        case emitc::CmpPredicate::gt:
+        case emitc::CmpPredicate::ge:
+          return 9;
+        case emitc::CmpPredicate::three_way:
+          return 10;
+        }
+        // Every predicate returns above; this keeps the lambda from falling
+        // off its end (and silences -Wreturn-type).
+        llvm_unreachable("unsupported comparison predicate");
+      })
+      .Case<emitc::DivOp>([&](auto op) { return 12; })
+      .Case<emitc::MulOp>([&](auto op) { return 12; })
+      .Case<emitc::RemOp>([&](auto op) { return 12; })
+      .Case<emitc::SubOp>([&](auto op) { return 11; })
+      .Case<emitc::CallOp>([&](auto op) { return 14; })
+      .Default([&](Operation *) {
+        llvm_unreachable("Unsupported operator");
+        return 0;
+      });
+}
+
namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
@@ -115,6 +148,12 @@ struct CppEmitter {
/// Emits the operands of the operation. All operands are emitted in order.
LogicalResult emitOperands(Operation &op);
+ /// Emits value as an operand of an operation. If the value is produced
+ /// inside the expression currently being emitted, or by an expression that
+ /// should be inlined, the C code computing it is emitted instead of a name.
+ LogicalResult emitOperand(Value value);
+
+ /// Emit the body of \p expressionOp as a single C expression, without an
+ /// assignment prefix or trailing semicolon.
+ LogicalResult emitExpression(ExpressionOp expressionOp);
+
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);
@@ -156,6 +195,21 @@ struct CppEmitter {
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
+  /// Get expression currently being emitted (null when none is).
+  ExpressionOp getEmittedExpression() { return emittedExpression; }
+
+  /// Determine whether given value is part of the expression potentially being
+  /// emitted, i.e. whether it is defined by an op nested directly in the body
+  /// of the expression currently being emitted.
+  bool isPartOfCurrentExpression(Value value) {
+    if (!emittedExpression)
+      return false;
+    Operation *def = value.getDefiningOp();
+    if (!def)
+      return false;
+    // Compare the parent directly: getParentOp() may be null (detached or
+    // top-level op), which dyn_cast<> would assert on.
+    return def->getParentOp() == emittedExpression.getOperation();
+  }
+
private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -178,9 +232,50 @@ struct CppEmitter {
/// names of values in a scope.
std::stack<int64_t> valueInScopeCount;
std::stack<int64_t> labelInScopeCount;
+
+  /// The expression currently being emitted; null outside of one.
+  ExpressionOp emittedExpression;
+  /// Precedence contexts within the expression being emitted; the back entry
+  /// is the innermost context.
+  SmallVector<int> emittedExpressionPrecedence;
+
+  /// Enter a nested context of the given precedence.
+  void pushExpressionPrecedence(int precedence) {
+    emittedExpressionPrecedence.push_back(precedence);
+  }
+  /// Leave the innermost context.
+  void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
+  /// The precedence used for a fresh context (and outside any expression).
+  static int lowestPrecedence() { return 0; }
+  /// Precedence of the innermost context, or the lowest one when empty.
+  int getExpressionPrecedence() {
+    return emittedExpressionPrecedence.empty()
+               ? lowestPrecedence()
+               : emittedExpressionPrecedence.back();
+  }
};
} // namespace
+/// Determine whether expression \p expressionOp should be emitted inline, i.e.
+/// as part of its user instead of being assigned to a variable of its own.
+/// Any expression that can be inlined is, except when its user lives inside
+/// another expression: expression fusion/re-materialization is assumed to have
+/// been handled by transformations run by the backend.
+static bool shouldBeInlined(ExpressionOp expressionOp) {
+  // Explicitly opted out, or inlining could reorder side effects.
+  if (expressionOp.getDoNotInline() || expressionOp.hasSideEffects())
+    return false;
+
+  // Only a value with exactly one use can be folded into that use.
+  Value exprResult = expressionOp.getResult();
+  if (!exprResult.hasOneUse())
+    return false;
+
+  // A user nested in another expression means folding already had its chance.
+  Operation *soleUser = *exprResult.user_begin();
+  return !soleUser->getParentOfType<ExpressionOp>();
+}
+
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
Attribute value) {
OpResult result = operation->getResult(0);
@@ -253,9 +348,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitVariableAssignment(result)))
return failure();
- emitter.ostream() << emitter.getOrCreateName(assignOp.getValue());
-
- return success();
+ return emitter.emitOperand(assignOp.getValue());
}
static LogicalResult printBinaryOperation(CppEmitter &emitter,
@@ -265,9 +358,14 @@ static LogicalResult printBinaryOperation(CppEmitter &emitter,
if (failed(emitter.emitAssignPrefix(*operation)))
return failure();
- os << emitter.getOrCreateName(operation->getOperand(0));
- os << " " << binaryOperator;
- os << " " << emitter.getOrCreateName(operation->getOperand(1));
+
+ if (failed(emitter.emitOperand(operation->getOperand(0))))
+ return failure();
+
+ os << " " << binaryOperator << " ";
+
+ if (failed(emitter.emitOperand(operation->getOperand(1))))
+ return failure();
return success();
}
@@ -483,9 +581,20 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
os << ") ";
- os << emitter.getOrCreateName(castOp.getOperand());
+ return emitter.emitOperand(castOp.getOperand());
+}
- return success();
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ExpressionOp expressionOp) {
+  // Inlined expressions produce no statement of their own; their C code is
+  // emitted where their result is used (see CppEmitter::emitOperand).
+  if (shouldBeInlined(expressionOp))
+    return success();
+
+  // Otherwise emit the assignment prefix followed by the expression itself.
+  if (failed(emitter.emitAssignPrefix(*expressionOp.getOperation())))
+    return failure();
+  return emitter.emitExpression(expressionOp);
}
static LogicalResult printOperation(CppEmitter &emitter,
@@ -505,6 +614,19 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
raw_indented_ostream &os = emitter.ostream();
+ // Utility function to determine whether a value is an expression that will be
+ // inlined, and as such should be wrapped in parentheses in order to guarantee
+ // its precedence and associativity.
+ auto requiresParentheses = [&](Value value) {
+ Operation *def = value.getDefiningOp();
+ if (!def)
+ return false;
+ auto expressionOp = dyn_cast<ExpressionOp>(def);
+ if (!expressionOp)
+ return false;
+ return shouldBeInlined(expressionOp);
+ };
+
os << "for (";
if (failed(
emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
@@ -512,15 +634,24 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
os << " ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << " = ";
- os << emitter.getOrCreateName(forOp.getLowerBound());
+ if (failed(emitter.emitOperand(forOp.getLowerBound())))
+ return failure();
os << "; ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << " < ";
- os << emitter.getOrCreateName(forOp.getUpperBound());
+ Value upperBound = forOp.getUpperBound();
+ bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
+ if (upperBoundRequiresParentheses)
+ os << "(";
+ if (failed(emitter.emitOperand(upperBound)))
+ return failure();
+ if (upperBoundRequiresParentheses)
+ os << ")";
os << "; ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << " += ";
- os << emitter.getOrCreateName(forOp.getStep());
+ if (failed(emitter.emitOperand(forOp.getStep())))
+ return failure();
os << ") {\n";
os.indent();
@@ -555,7 +686,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
};
os << "if (";
- if (failed(emitter.emitOperands(*ifOp.getOperation())))
+ if (failed(emitter.emitOperand(ifOp.getCondition())))
return failure();
os << ") {\n";
os.indent();
@@ -583,8 +714,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
case 0:
return success();
case 1:
- os << " " << emitter.getOrCreateName(returnOp.getOperand(0));
- return success(emitter.hasValueInScope(returnOp.getOperand(0)));
+ os << " ";
+ if (failed(emitter.emitOperand(returnOp.getOperand(0))))
+ return failure();
+ return success();
default:
os << " std::make_tuple(";
if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
@@ -637,7 +770,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
// regions.
WalkResult result =
functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
- if (isa<emitc::LiteralOp>(op))
+ if (isa<emitc::LiteralOp>(op) ||
+ isa<emitc::ExpressionOp>(op->getParentOp()) ||
+ (isa<emitc::ExpressionOp>(op) &&
+ shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return WalkResult::skip();
for (OpResult result : op->getResults()) {
if (failed(emitter.emitVariableDeclaration(
@@ -839,15 +975,70 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return emitError(loc, "cannot emit attribute: ") << attr;
}
+/// Emit \p expressionOp's body as a single C expression by emitting its root
+/// op; operands defined inside the body are emitted recursively through
+/// emitOperand. While active, emittedExpression is set and the precedence
+/// stack holds the enclosing contexts.
+LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
+  assert(emittedExpressionPrecedence.empty() &&
+         "Expected precedence stack to be empty");
+  Operation *rootOp = expressionOp.getRootOp();
+
+  emittedExpression = expressionOp;
+  pushExpressionPrecedence(getOperatorPrecedence(rootOp));
+
+  if (failed(emitOperation(*rootOp, /*trailingSemicolon=*/false))) {
+    // Do not leave the emitter mid-expression: reset all expression state,
+    // including any entries a failing nested operand may have left behind.
+    emittedExpressionPrecedence.clear();
+    emittedExpression = nullptr;
+    return failure();
+  }
+
+  popExpressionPrecedence();
+  assert(emittedExpressionPrecedence.empty() &&
+         "Expected precedence stack to be empty");
+  emittedExpression = nullptr;
+
+  return success();
+}
+
+/// Emit \p value as an operand: inline the C code of its defining op when it
+/// belongs to the expression being emitted or is an inlinable expression, and
+/// otherwise print its name (failing with a diagnostic if it is not in scope).
+LogicalResult CppEmitter::emitOperand(Value value) {
+  if (isPartOfCurrentExpression(value)) {
+    Operation *def = value.getDefiningOp();
+    assert(def && "Expected operand to be defined by an operation");
+    int precedence = getOperatorPrecedence(def);
+    // Parenthesize sub-expressions whose precedence is lower than OR EQUAL to
+    // their context. Equal precedence must be included: otherwise a - (b - c)
+    // would be printed as a - b - c, which C evaluates as (a - b) - c.
+    bool encloseInParenthesis = precedence <= getExpressionPrecedence();
+    if (encloseInParenthesis) {
+      os << "(";
+      pushExpressionPrecedence(lowestPrecedence());
+    } else {
+      pushExpressionPrecedence(precedence);
+    }
+
+    LogicalResult emitted = emitOperation(*def, /*trailingSemicolon=*/false);
+    // Pop even on failure so the precedence stack stays balanced.
+    popExpressionPrecedence();
+    if (failed(emitted))
+      return failure();
+
+    if (encloseInParenthesis)
+      os << ")";
+    return success();
+  }
+
+  // An expression that should be inlined is emitted in place of its result.
+  auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
+  if (expressionOp && shouldBeInlined(expressionOp))
+    return emitExpression(expressionOp);
+
+  // Otherwise refer to the value by name; literals need no declaration.
+  auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
+  if (!literalOp && !hasValueInScope(value))
+    return emitError(value.getLoc(), "operand value not in scope");
+  os << getOrCreateName(value);
+  return success();
+}
+
LogicalResult CppEmitter::emitOperands(Operation &op) {
- auto emitOperandName = [&](Value result) -> LogicalResult {
- auto literalDef = dyn_cast_if_present<LiteralOp>(result.getDefiningOp());
- if (!literalDef && !hasValueInScope(result))
- return op.emitOpError() << "operand value not in scope";
- os << getOrCreateName(result);
- return success();
- };
- return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
+  return interleaveCommaWithError(op.getOperands(), os, [&](Value operand) {
+    // If an expression is being emitted, each comma-separated operand starts
+    // a fresh context at the lowest precedence, so none of them gets wrapped
+    // in additional parentheses.
+    bool inExpression = static_cast<bool>(getEmittedExpression());
+    if (inExpression)
+      pushExpressionPrecedence(lowestPrecedence());
+    LogicalResult emitted = emitOperand(operand);
+    // Pop regardless of the outcome so a failing operand keeps the stack
+    // balanced.
+    if (inExpression)
+      popExpressionPrecedence();
+    return emitted;
+  });
}
LogicalResult
@@ -900,6 +1091,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
}
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
+ // If op is being emitted as part of an expression, bail out.
+ if (getEmittedExpression())
+ return success();
+
switch (op.getNumResults()) {
case 0:
break;
@@ -950,8 +1145,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
- emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
- emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+ emitc::ExpressionOp, emitc::ForOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
+ emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
@@ -970,7 +1166,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (isa<emitc::LiteralOp>(op))
return success();
+ if (getEmittedExpression() ||
+ (isa<emitc::ExpressionOp>(op) &&
+ shouldBeInlined(cast<emitc::ExpressionOp>(op))))
+ return success();
+
os << (trailingSemicolon ? ";\n" : "\n");
+
return success();
}
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 53d88adf4305ff8..516d1b69fbc88aa 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
// -----
func.func @test_misplaced_yield() {
- // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}}
+ // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.expression, emitc.if, emitc.for'}}
emitc.yield
return
}
@@ -224,3 +224,60 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
emitc.assign %arg1 : f32 to %v : i32
return
}
+
+// -----
+
+// An expression body that does not end in emitc.yield is rejected.
+func.func @test_expression_no_yield() -> i32 {
+ // expected-error @+1 {{'emitc.expression' op must yield a value at termination}}
+ %r = emitc.expression : i32 {
+ %c7 = "emitc.constant"(){value = 7 : i32} : () -> i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// emitc.variable is not allowed inside an expression body.
+func.func @test_expression_illegal_op(%arg0 : i1) -> i32 {
+ // expected-error @+1 {{'emitc.expression' op contains an unsupported operation}}
+ %r = emitc.expression : i32 {
+ %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ emitc.yield %x : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// %b has no use inside the body, violating the one-use-per-op rule.
+func.func @test_expression_no_use(%arg0: i32, %arg1: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
+ %r = emitc.expression : i32 {
+ %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ emitc.yield %a : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// %a is used several times inside the body, violating the one-use-per-op rule.
+func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' op requires exactly one use for each operation}}
+ %r = emitc.expression : i32 {
+ %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.mul %arg1, %a : (i32, i32) -> i32
+ emitc.yield %a : i32
+ }
+ return %r : i32
+}
+
+// -----
+
+// Ops inside an expression must have exactly one result; this call has two.
+func.func @test_expression_multiple_results(%arg0: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' op requires exactly one result for each operation}}
+ %r = emitc.expression : i32 {
+ %a:2 = emitc.call "bar" (%arg0) : (i32) -> (i32, i32)
+ emitc.yield %a : i32
+ }
+ return %r : i32
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 6c8398680980466..90d9db2615ef510 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -128,6 +128,23 @@ func.func @test_assign(%arg1: f32) {
return
}
+// Valid emitc.expression ops, with and without the do_not_inline attribute;
+// the second body mixes arithmetic, a call (using the first expression's
+// result) and a cast.
+func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4: f32) -> i32 {
+ %c7 = "emitc.constant"() {value = 7 : i32} : () -> i32
+ %q = emitc.expression : i32 {
+ %a = emitc.rem %arg1, %c7 : (i32, i32) -> i32
+ emitc.yield %a : i32
+ }
+ %r = emitc.expression {do_not_inline} : i32 {
+ %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.call "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32)
+ %c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32
+ %d = emitc.cast %c : f32 to i32
+ %e = emitc.sub %b, %d : (i32, i32) -> i32
+ emitc.yield %e : i32
+ }
+ return %r : i32
+}
+
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
emitc.for %i0 = %arg0 to %arg1 step %arg2 {
%0 = emitc.call "func_const"(%i0) : (index) -> i32
diff --git a/mlir/test/Dialect/EmitC/transform-ops.mlir b/mlir/test/Dialect/EmitC/transform-ops.mlir
new file mode 100644
index 000000000000000..d3c77e1147cfe05
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/transform-ops.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck --check-prefix=ROUNDTRIP %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --form-expressions --verify-diagnostics --split-input-file | FileCheck %s
+
+// ROUNDTRIP-LABEL: transform.sequence failures(propagate) {
+// ROUNDTRIP-NEXT: ^bb0(%[[VAL_0:.*]]: !transform.any_op):
+// ROUNDTRIP-NEXT: %[[VAL_1:.*]] = transform.expression.match.root in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+// ROUNDTRIP-NEXT: %[[VAL_2:.*]] = transform.expression.create %[[VAL_1]] : (!transform.any_op) -> !transform.any_op
+// ROUNDTRIP-NEXT: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
+// ROUNDTRIP-NEXT: apply_patterns to %[[VAL_3]] {
+// ROUNDTRIP-NEXT: transform.apply_patterns.emitc.expressions
+// ROUNDTRIP-NEXT: } : !transform.any_op
+// ROUNDTRIP-NEXT: }
+
+// Transform-dialect counterpart of --form-expressions (both RUN lines share
+// the same expected output): wrap the ops matched by
+// transform.expression.match.root in expressions, then apply the expression
+// folding patterns to every func.func.
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+ %0 = transform.expression.match.root in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.expression.create %0 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %2 {
+ transform.apply_patterns.emitc.expressions
+ } : !transform.any_op
+}
+
+// CHECK-LABEL: func.func @single_expression(
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 {
+// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 42 : i32}> : () -> i32
+// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_4]] : (i32, i32) -> i32
+// CHECK: %[[VAL_7:.*]] = emitc.sub %[[VAL_6]], %[[VAL_2]] : (i32, i32) -> i32
+// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_3]] : (i32, i32) -> i1
+// CHECK: emitc.yield %[[VAL_8]] : i1
+// CHECK: }
+// CHECK: return %[[VAL_5]] : i1
+// CHECK: }
+
+// A chain of single-use operators folds into one expression; the constant
+// stays outside of it.
+func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 {
+ %c42 = "emitc.constant"(){value = 42 : i32} : () -> i32
+ %a = emitc.mul %arg0, %c42 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg2 : (i32, i32) -> i32
+ %c = emitc.cmp lt, %b, %arg3 :(i32, i32) -> i1
+ return %c : i1
+}
+
+// CHECK-LABEL: func.func @multiple_expressions(
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
+// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_6:.*]] = emitc.sub %[[VAL_5]], %[[VAL_2]] : (i32, i32) -> i32
+// CHECK: emitc.yield %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_8:.*]] = emitc.add %[[VAL_1]], %[[VAL_3]] : (i32, i32) -> i32
+// CHECK: %[[VAL_9:.*]] = emitc.div %[[VAL_8]], %[[VAL_2]] : (i32, i32) -> i32
+// CHECK: emitc.yield %[[VAL_9]] : i32
+// CHECK: }
+// CHECK: return %[[VAL_4]], %[[VAL_7]] : i32, i32
+// CHECK: }
+
+// Two independent operator chains form two separate expressions.
+func.func @multiple_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> (i32, i32) {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg2 : (i32, i32) -> i32
+ %c = emitc.add %arg1, %arg3 : (i32, i32) -> i32
+ %d = emitc.div %c, %arg2 : (i32, i32) -> i32
+ return %b, %d : i32, i32
+}
+
+// CHECK-LABEL: func.func @expression_with_call(
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i1 {
+// CHECK: %[[VAL_4:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_5:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_6:.*]] = emitc.call "foo"(%[[VAL_5]], %[[VAL_2]]) : (i32, i32) -> i32
+// CHECK: emitc.yield %[[VAL_6]] : i32
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_4]], %[[VAL_1]] : (i32, i32) -> i1
+// CHECK: emitc.yield %[[VAL_8]] : i1
+// CHECK: }
+// CHECK: return %[[VAL_7]] : i1
+// CHECK: }
+
+// The multiplication folds into the call's expression, but the call's
+// expression is not folded into the comparison (presumably because calls are
+// treated as having side effects).
+func.func @expression_with_call(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.call "foo" (%a, %arg2) : (i32, i32) -> (i32)
+ %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+ return %c : i1
+}
+
+// CHECK-LABEL: func.func @expression_with_dereference(
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr<i32>) -> i1 {
+// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_4:.*]] = emitc.apply "*"(%[[VAL_2]]) : (!emitc.ptr<i32>) -> i32
+// CHECK: emitc.yield %[[VAL_4]] : i32
+// CHECK: }
+// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_6:.*]] = emitc.mul %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_7:.*]] = emitc.cmp lt, %[[VAL_6]], %[[VAL_3]] : (i32, i32) -> i1
+// CHECK: emitc.yield %[[VAL_7]] : i1
+// CHECK: }
+// CHECK: return %[[VAL_5]] : i1
+// CHECK: }
+
+// The dereference keeps its own expression and is not folded into the
+// comparison (presumably treated as side-effecting), while the multiplication is.
+func.func @expression_with_dereference(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.apply "*"(%arg2) : (!emitc.ptr<i32>) -> (i32)
+ %c = emitc.cmp lt, %a, %b :(i32, i32) -> i1
+ return %c : i1
+}
+
+// CHECK-LABEL: func.func @expression_with_address_taken(
+// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: !emitc.ptr<i32>) -> i1 {
+// CHECK: %[[VAL_3:.*]] = emitc.expression : i32 {
+// CHECK: %[[VAL_4:.*]] = emitc.rem %[[VAL_0]], %[[VAL_1]] : (i32, i32) -> i32
+// CHECK: emitc.yield %[[VAL_4]] : i32
+// CHECK: }
+// CHECK: %[[VAL_5:.*]] = emitc.expression : i1 {
+// CHECK: %[[VAL_6:.*]] = emitc.apply "&"(%[[VAL_3]]) : (i32) -> !emitc.ptr<i32>
+// CHECK: %[[VAL_7:.*]] = emitc.add %[[VAL_6]], %[[VAL_1]] : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
+// CHECK: %[[VAL_8:.*]] = emitc.cmp lt, %[[VAL_7]], %[[VAL_2]] : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
+// CHECK: emitc.yield %[[VAL_8]] : i1
+// CHECK: }
+// CHECK: return %[[VAL_5]] : i1
+// CHECK: }
+
+// The operand of apply "&" has its address taken, so the expression computing
+// it is not folded into the expression using it.
+func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
+ %a = emitc.rem %arg0, %arg1 : (i32, i32) -> (i32)
+ %b = emitc.apply "&"(%a) : (i32) -> !emitc.ptr<i32>
+ %c = emitc.add %b, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
+ %d = emitc.cmp lt, %c, %arg2 :(!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
+ return %d : i1
+}
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
new file mode 100644
index 000000000000000..16465368520411b
--- /dev/null
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -0,0 +1,212 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
+
+// CPP-DEFAULT: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
+// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
+// CPP-DEFAULT-NEXT: if ([[VAL_5]]) {
+// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: } else {
+// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: }
+// CPP-DEFAULT-NEXT: return [[VAL_6]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t single_use(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * M_PI, [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: if ([[VAL_5]]) {
+// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: } else {
+// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: }
+// CPP-DECLTOP-NEXT: return [[VAL_6]];
+// CPP-DECLTOP-NEXT: }
+
+func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+ %p0 = emitc.literal "M_PI" : i32
+ %e = emitc.expression : i1 {
+ %a = emitc.mul %arg0, %p0 : (i32, i32) -> i32
+ %b = emitc.call "bar" (%a, %arg2) : (i32, i32) -> (i32)
+ %c = emitc.sub %b, %arg3 : (i32, i32) -> i32
+ %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1
+ emitc.yield %d : i1
+ }
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32
+ emitc.if %e {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ } else {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ }
+ return %v : i32
+}
+
+// CPP-DEFAULT: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]];
+// CPP-DEFAULT-NEXT: return [[VAL_4]];
+// CPP-DEFAULT-NEXT:}
+
+// CPP-DECLTOP: int32_t do_not_inline(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_4]] = ([[VAL_1]] + [[VAL_2]]) * [[VAL_3]];
+// CPP-DECLTOP-NEXT: return [[VAL_4]];
+// CPP-DECLTOP-NEXT:}
+
+func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
+ %e = emitc.expression {do_not_inline} : i32 {
+ %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.mul %a, %arg2 : (i32, i32) -> i32
+ emitc.yield %b : i32
+ }
+ return %e : i32
+}
+
+// CPP-DEFAULT: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: float parentheses_for_low_precedence(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: return (float) ([[VAL_1]] + [[VAL_2]] * [[VAL_3]]);
+// CPP-DECLTOP-NEXT: }
+
+func.func @parentheses_for_low_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -> f32 {
+ %e = emitc.expression : f32 {
+ %a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.mul %a, %arg2 : (i32, i32) -> i32
+ %d = emitc.cast %b : i32 to f32
+ emitc.yield %d : f32
+ }
+ return %e : f32
+}
+
+// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
+// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];
+// CPP-DEFAULT-NEXT: if ([[VAL_5]]) {
+// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: } else {
+// CPP-DEFAULT-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: }
+// CPP-DEFAULT-NEXT: bool [[VAL_7:v[0-9]+]];
+// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_5]];
+// CPP-DEFAULT-NEXT: return [[VAL_6]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: bool [[VAL_5:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]];
+// CPP-DECLTOP-NEXT: bool [[VAL_7:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_5]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: if ([[VAL_5]]) {
+// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: } else {
+// CPP-DECLTOP-NEXT: [[VAL_6]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: }
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_5]];
+// CPP-DECLTOP-NEXT: return [[VAL_6]];
+// CPP-DECLTOP-NEXT: }
+
+func.func @multiple_uses(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+ %e = emitc.expression : i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.call "bar" (%a, %arg2) : (i32, i32) -> (i32)
+ %c = emitc.sub %b, %arg3 : (i32, i32) -> i32
+ %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1
+ emitc.yield %d : i1
+ }
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32
+ emitc.if %e {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ } else {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ }
+ %q = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i1
+ emitc.assign %e : i1 to %q : i1
+ return %v : i32
+}
+
+// CPP-DEFAULT: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_5:v[0-9]+]] = [[VAL_3]] % [[VAL_4]];
+// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
+// CPP-DEFAULT-NEXT: int32_t [[VAL_7:v[0-9]+]];
+// CPP-DEFAULT-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: } else {
+// CPP-DEFAULT-NEXT: [[VAL_7]] = [[VAL_1]];
+// CPP-DEFAULT-NEXT: }
+// CPP-DEFAULT-NEXT: return [[VAL_7]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: int32_t different_expressions(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_5:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAL_6:v[0-9]+]];
+// CPP-DECLTOP-NEXT: int32_t [[VAL_7:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_5]] = [[VAL_3]] % [[VAL_4]];
+// CPP-DECLTOP-NEXT: [[VAL_6]] = bar([[VAL_5]], [[VAL_1]] * [[VAL_2]]);
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: if ([[VAL_6]] - [[VAL_4]] < [[VAL_2]]) {
+// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: } else {
+// CPP-DECLTOP-NEXT: [[VAL_7]] = [[VAL_1]];
+// CPP-DECLTOP-NEXT: }
+// CPP-DECLTOP-NEXT: return [[VAL_7]];
+// CPP-DECLTOP-NEXT: }
+
+func.func @different_expressions(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
+ %e1 = emitc.expression : i32 {
+ %a = emitc.rem %arg2, %arg3 : (i32, i32) -> i32
+ emitc.yield %a : i32
+ }
+ %e2 = emitc.expression : i32 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.call "bar" (%e1, %a) : (i32, i32) -> (i32)
+ emitc.yield %b : i32
+ }
+ %e3 = emitc.expression : i1 {
+ %c = emitc.sub %e2, %arg3 : (i32, i32) -> i32
+ %d = emitc.cmp lt, %c, %arg1 :(i32, i32) -> i1
+ emitc.yield %d : i1
+ }
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> i32
+ emitc.if %e3 {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ } else {
+ emitc.assign %arg0 : i32 to %v : i32
+ emitc.yield
+ }
+ return %v : i32
+}
+
+// CPP-DEFAULT: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3:v[0-9]+]]) {
+// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = [[VAL_1]] % [[VAL_2]];
+// CPP-DEFAULT-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DEFAULT-NEXT: }
+
+// CPP-DECLTOP: bool expression_with_address_taken(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t* [[VAL_3:v[0-9]+]]) {
+// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
+// CPP-DECLTOP-NEXT: [[VAL_4]] = [[VAL_1]] % [[VAL_2]];
+// CPP-DECLTOP-NEXT: return &[[VAL_4]] - [[VAL_2]] < [[VAL_3]];
+// CPP-DECLTOP-NEXT: }
+
+func.func @expression_with_address_taken(%arg0: i32, %arg1: i32, %arg2: !emitc.ptr<i32>) -> i1 {
+ %a = emitc.expression : i32 {
+ %b = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ emitc.yield %b : i32
+ }
+ %c = emitc.expression : i1 {
+ %d = emitc.apply "&"(%a) : (i32) -> !emitc.ptr<i32>
+ %e = emitc.sub %d, %arg1 : (!emitc.ptr<i32>, i32) -> !emitc.ptr<i32>
+ %f = emitc.cmp lt, %e, %arg2 : (!emitc.ptr<i32>, !emitc.ptr<i32>) -> i1
+ emitc.yield %f : i1
+ }
+ return %c : i1
+}
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index c02c8b1ac33e371..c8f5206a083323a 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -2,20 +2,32 @@
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
- emitc.for %i0 = %arg0 to %arg1 step %arg2 {
+ %lb = emitc.expression : index {
+ %a = emitc.add %arg0, %arg1 : (index, index) -> index
+ emitc.yield %a : index
+ }
+ %ub = emitc.expression : index {
+ %a = emitc.mul %arg1, %arg2 : (index, index) -> index
+ emitc.yield %a : index
+ }
+ %step = emitc.expression : index {
+ %a = emitc.div %arg0, %arg2 : (index, index) -> index
+ emitc.yield %a : index
+ }
+ emitc.for %i0 = %lb to %ub step %step {
%0 = emitc.call "f"() : () -> i32
}
return
}
-// CPP-DEFAULT: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) {
-// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
+// CPP-DEFAULT: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
+// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
// CPP-DEFAULT-NEXT: int32_t [[V4:[^ ]*]] = f();
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
-// CPP-DECLTOP: void test_for(size_t [[START:[^ ]*]], size_t [[STOP:[^ ]*]], size_t [[STEP:[^ ]*]]) {
+// CPP-DECLTOP: void test_for(size_t [[V1:[^ ]*]], size_t [[V2:[^ ]*]], size_t [[V3:[^ ]*]]) {
// CPP-DECLTOP-NEXT: int32_t [[V4:[^ ]*]];
-// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
+// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[V1]] + [[V2]]; [[ITER]] < ([[V2]] * [[V3]]); [[ITER]] += [[V1]] / [[V3]]) {
// CPP-DECLTOP-NEXT: [[V4]] = f();
// CPP-DECLTOP-NEXT: }
// CPP-DECLTOP-NEXT: return;
>From dd640a1e08a72792abfc43f0c7d0461ca9ff54e9 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Wed, 8 Nov 2023 11:26:57 +0200
Subject: [PATCH 02/21] Update
mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
Co-authored-by: Marius Brehler <marius.brehler at iml.fraunhofer.de>
---
.../mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
index 1e442af2e4c1edb..570b0defdfcd824 100644
--- a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
+++ b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.td
@@ -1,4 +1,4 @@
-//===- EmitCTransformOps.td - EmitC (loop) transformation ops --*- tablegen -*-===//
+//===- EmitCTransformOps.td - EmitC transformation ops -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From fdd3577ff2aa796271e0f38083cfa6d41b0584f5 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Wed, 8 Nov 2023 11:27:35 +0200
Subject: [PATCH 03/21] Update
mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
Co-authored-by: Marius Brehler <marius.brehler at iml.fraunhofer.de>
---
.../mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
index 5b31080c70d9fcb..3da7d61d7ded478 100644
--- a/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
+++ b/mlir/include/mlir/Dialect/EmitC/TransformOps/EmitCTransformOps.h
@@ -1,5 +1,4 @@
-//===- EmitCTransformOps.h - EmitC transformation ops ---------------*- C++
-//-*-===//
+//===- EmitCTransformOps.h - EmitC transformation ops -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From 80babdd8eff8f484676be0fbe3ad719a546d28a2 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Wed, 8 Nov 2023 11:38:47 +0200
Subject: [PATCH 04/21] Fix clang format
---
mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
index ed35c914f06066a..45ac0f9682c3787 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
@@ -11,10 +11,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace emitc {
@@ -28,7 +28,7 @@ using namespace emitc;
namespace {
struct FormExpressionsPass
- : public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
+ : public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
void runOnOperation() override {
Operation *rootOp = getOperation();
MLIRContext *context = rootOp->getContext();
>From 1013c145b1025597f9863c6c5b7a9041eca29131 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 09:03:35 +0200
Subject: [PATCH 05/21] Update mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Use standard assembly format (1/2)
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 4d522565c32826d..36392580ca59131 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -288,11 +288,11 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
}];
let arguments = (ins UnitAttr:$do_not_inline);
- let results = (outs AnyType);
+ let results = (outs AnyType:$result);
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "attr-dict `:` type($result) $region";
let extraClassDeclaration = [{
static bool isCExpression(Operation &op) {
>From 51016b2c2ff83bbb2f4bcf0e79b5de29559f9787 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 09:04:40 +0200
Subject: [PATCH 06/21] Update mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Use standard assembly format (2/2)
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 31 -----------------------------
1 file changed, 31 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 26599f62680319f..2266fb2ff2a9503 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -201,37 +201,6 @@ Operation *ExpressionOp::getRootOp() {
return rootOp;
}
-ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the optional attribute list.
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- // Parse results type.
- Type expressionType;
- if (parser.parseColonType(expressionType))
- return failure();
- result.addTypes(expressionType);
-
- // Create the expression's body region.
- result.regions.reserve(1);
- Region *region = result.addRegion();
-
- // Parse the region.
- if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
- return failure();
-
- return success();
-}
-
-void ExpressionOp::print(OpAsmPrinter &p) {
- p.printOptionalAttrDict((*this)->getAttrs());
-
- p << " : " << getResult().getType() << ' ';
-
- p.printRegion(getRegion(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/true);
-}
LogicalResult ExpressionOp::verify() {
Type resultType = getResult().getType();
>From 640083785b779d9c9a3ab927355f4fda523f4b9c Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 09:31:39 +0200
Subject: [PATCH 07/21] Update mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Replace loop with any_of
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 36392580ca59131..43f5e34d8683cda 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -311,12 +311,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
return true;
// Any operator using variables has a side effect of reading memory mutable by
// emitc::assign ops.
- for (Value operand : op.getOperands()) {
- Operation *def = operand.getDefiningOp();
- if (def && isa<emitc::VariableOp>(def))
- return true;
- }
- return false;
+ return llvm::any_of(op.getOperands(),
+ [](Value operand) { Operation *def = operand.getDefiningOp();
+ return def && isa<emitc::VariableOp>(def); });
}
bool hasSideEffects() {
return llvm::any_of(getRegion().front().without_terminator(),
>From dd6651d97c3eafc91550929bca02abc8b02229cc Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 10:07:56 +0200
Subject: [PATCH 08/21] Update mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 347e8a2b305da54..36649b40f7006e4 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -99,7 +99,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
assert(clonedExpressionRootOp->getNumResults() == 1 &&
"Expected cloned root to have a single result");
- Value clonedExpressionResult = clonedExpressionRootOp->getResults()[0];
+ Value clonedExpressionResult = clonedExpressionRootOp->getResult(0);
usedExpression.getResult().replaceAllUsesWith(clonedExpressionResult);
rewriter.eraseOp(usedExpression);
>From 7c4f81d5782426630fd661c7796ac081abf12884 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 10:08:05 +0200
Subject: [PATCH 09/21] Update mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 2266fb2ff2a9503..4a40dc3ba244a53 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -227,7 +227,7 @@ LogicalResult ExpressionOp::verify() {
return emitOpError("contains an unsupported operation");
if (op.getNumResults() != 1)
return emitOpError("requires exactly one result for each operation");
- if (!op.getResults()[0].hasOneUse())
+ if (!op.getResult(0).hasOneUse())
return emitOpError("requires exactly one use for each operation");
}
>From 67a8f404c3701b9a575ceee8f148257e2933d009 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Thu, 9 Nov 2023 10:08:13 +0200
Subject: [PATCH 10/21] Update mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 36649b40f7006e4..33524405c6666a2 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -20,7 +20,7 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
// Create an expression yielding the value returned by op.
assert(op->getNumResults() == 1 && "Expected exactly one result");
- Value result = op->getResults()[0];
+ Value result = op->getResult(0);
Type resultType = result.getType();
Location loc = op->getLoc();
>From bdb0580bc1c508f5ca94d64308ccd31c9fe90727 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Thu, 9 Nov 2023 10:51:44 +0200
Subject: [PATCH 11/21] Clean up
---
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4a40dc3ba244a53..ed23cef4e394435 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -201,7 +201,6 @@ Operation *ExpressionOp::getRootOp() {
return rootOp;
}
-
LogicalResult ExpressionOp::verify() {
Type resultType = getResult().getType();
Region ®ion = getRegion();
>From 403b6c6b96088cb7483d19ef64b26529cc0bc134 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 10:51:50 +0200
Subject: [PATCH 12/21] Update mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 33524405c6666a2..fa5d819049f5c37 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -89,10 +89,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
mapper.map(&opToClone, clone);
}
- auto usedYield =
- cast<YieldOp>(usedExpression.getBody()->getTerminator());
- Operation *expressionRoot = usedYield.getResult().getDefiningOp();
- assert(expressionRoot && "Used expression has no root operation");
+ Operation *expressionRoot = usedExpression.getRootOp();
Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
assert(clonedExpressionRootOp &&
"Expected cloned expression root to be in mapper");
>From f0d8b07ef451e247a1d493c12f4a422fb7071b0d Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 10:54:19 +0200
Subject: [PATCH 13/21] Update mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index fa5d819049f5c37..4e0392a5d6d9971 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -64,7 +64,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
for (Value operand : op.getOperands()) {
auto usedExpression =
- dyn_cast_or_null<ExpressionOp>(operand.getDefiningOp());
+ dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
if (!usedExpression)
continue;
>From e8eb5c443621d384676ae2bf23018686f611375f Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:02:24 +0200
Subject: [PATCH 14/21] Update mlir/lib/Target/Cpp/TranslateToCpp.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 9765c4a0387fb98..7add470dbc87624 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -618,10 +618,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
// inlined, and as such should be wrapped in parentheses in order to guarantee
// its precedence and associativity.
auto requiresParentheses = [&](Value value) {
- Operation *def = value.getDefiningOp();
- if (!def)
- return false;
- auto expressionOp = dyn_cast<ExpressionOp>(def);
+ auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
>From 952c0299bec2d76885a48e7afaacb2a042fddd33 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:02:42 +0200
Subject: [PATCH 15/21] Update
mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp b/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
index ebad021556f7433..7978051df5c7507 100644
--- a/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
+++ b/mlir/lib/Dialect/EmitC/TransformOps/EmitCTransformOps.cpp
@@ -1,4 +1,4 @@
-//===- LinalgTransformOps.cpp - Implementation of Linalg match ops --------===//
+//===- EmitCTransformOps.cpp - Implementation of EmitC match ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From 3eadc8a42811e7e6e43285cfb1782c8654ff45a2 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:04:16 +0200
Subject: [PATCH 16/21] Update mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 4e0392a5d6d9971..593d774cac73bd3 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -1,4 +1,4 @@
-//===- Transforms.cpp - Patterns and transforms for the EmitC dialect------===//
+//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
>From bea36debd2aef786c93552d91b7de47739e9ac05 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:41:38 +0200
Subject: [PATCH 17/21] Update mlir/test/Dialect/EmitC/ops.mlir
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/test/Dialect/EmitC/ops.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 90d9db2615ef510..244d2b1b834c9b4 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -134,7 +134,7 @@ func.func @test_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32, %arg4
%a = emitc.rem %arg1, %c7 : (i32, i32) -> i32
emitc.yield %a : i32
}
- %r = emitc.expression {do_not_inline} : i32 {
+ %r = emitc.expression noinline : i32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.call "bar" (%a, %arg2, %q) : (i32, i32, i32) -> (i32)
%c = emitc.mul %arg3, %arg4 : (f32, f32) -> f32
>From 2188b037b9dc9e7a4e62921da825319e5817fdda Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:41:53 +0200
Subject: [PATCH 18/21] Update mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 43f5e34d8683cda..f65990cff03d47f 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -292,7 +292,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
- let assemblyFormat = "attr-dict `:` type($result) $region";
+ let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
let extraClassDeclaration = [{
static bool isCExpression(Operation &op) {
>From 546300e098f4831d663bf691b6d444900e82b6ef Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:42:07 +0200
Subject: [PATCH 19/21] Update mlir/test/Target/Cpp/expressions.mlir
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/test/Target/Cpp/expressions.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Target/Cpp/expressions.mlir b/mlir/test/Target/Cpp/expressions.mlir
index 16465368520411b..ace1b7a11fae508 100644
--- a/mlir/test/Target/Cpp/expressions.mlir
+++ b/mlir/test/Target/Cpp/expressions.mlir
@@ -57,7 +57,7 @@ func.func @single_use(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> i32 {
// CPP-DECLTOP-NEXT:}
func.func @do_not_inline(%arg0: i32, %arg1: i32, %arg2 : i32) -> i32 {
- %e = emitc.expression {do_not_inline} : i32 {
+ %e = emitc.expression noinline : i32 {
%a = emitc.add %arg0, %arg1 : (i32, i32) -> i32
%b = emitc.mul %a, %arg2 : (i32, i32) -> i32
emitc.yield %b : i32
>From 4d877329c039caa2fbcb28eee0bc32528b84ca19 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:52:27 +0200
Subject: [PATCH 20/21] Update mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index f65990cff03d47f..72e12d406899ab5 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -301,7 +301,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
emitc::SubOp>(op);
}
static bool hasSideEffects(Operation &op) {
- assert(isCExpression(op) && "Expected a C operator");
+ assert(isCExpression(op) && "Expected a C expression");
// Conservatively assume calls to read and write memory.
if (isa<emitc::CallOp>(op))
return true;
>From a9411f7f2ba47d3385a31d0053509103d01319b1 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <aniragil at gmail.com>
Date: Fri, 10 Nov 2023 11:54:57 +0200
Subject: [PATCH 21/21] Update mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Co-authored-by: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 72e12d406899ab5..a9c60ffcd59c901 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -309,7 +309,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
auto applyOp = dyn_cast<emitc::ApplyOp>(op);
if (applyOp && applyOp.getApplicableOperator() == "*")
return true;
- // Any operator using variables has a side effect of reading memory mutable by
+ // Any operation using variables has a side effect of reading memory mutable by
// emitc::assign ops.
return llvm::any_of(op.getOperands(),
[](Value operand) { Operation *def = operand.getDefiningOp();
More information about the Mlir-commits
mailing list