[Mlir-commits] [mlir] [MLIR] EmitC: Add subscript operator (PR #84783)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 11 09:20:30 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
@llvm/pr-subscribers-mlir
Author: Matthias Gehre (mgehre-amd)
<details>
<summary>Changes</summary>
Introduces a `SubscriptOp` that allows writing IR like
```
func.func @<!-- -->load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
emitc.assign %0 : f32 to %1 : f32
return
}
```
which gets translated into the C++ code below (`v1`–`v4` are the names the emitter assigns to `%arg0`–`%arg3`):
```
v2[v3][v4] = v1[v3][v4];
```
To make this happen, this PR
- adds the `SubscriptOp`
- allows a subscript op as the lhs (the assigned-to `var` operand) of `emitc.assign`
- updates the emitter to print `SubscriptOp`s
The emitter prints `emitc.subscript` lazily (in a delayed fashion) so that it can be used as an lvalue.
That is, while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
```
it does not emit any text; instead it records in the `valueMapper` that the name for `%0` is the array name followed by the bracketed index names (here `v1[v3][v4]`; see `CppEmitter::getSubscriptName`). Only when that result is later used (here, in `emitc.assign`) is that name inserted into the output.
---
Full diff: https://github.com/llvm/llvm-project/pull/84783.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+32)
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+17-2)
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+34-7)
- (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+8)
- (added) mlir/test/Target/Cpp/subscript.mlir (+12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..bcdd001528c46d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1125,4 +1125,36 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}
+def EmitC_SubscriptOp : EmitC_Op<"subscript",
+ [TypesMatchWith<"result type matches element type of 'array'",
+ "array", "result",
+ "::llvm::cast<ArrayType>($_self).getElementType()">]> {
+ let summary = "Array subscript operation";
+ let description = [{
+ With the `subscript` operation the subscript operator `[]` can be applied
+ to variables or arguments of array type.
+
+ Example:
+
+ ```mlir
+ %i = index.constant 1
+ %j = index.constant 7
+ %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>
+ ```
+ }];
+ let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
+ Variadic<Index>:$indices);
+ let results = (outs AnyType:$result);
+
+ let builders = [
+ OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
+ build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+ }]>
+ ];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
+}
+
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 9426bbbe2370f0..e401a83bcb42e6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
LogicalResult emitc::AssignOp::verify() {
Value variable = getVar();
Operation *variableDef = variable.getDefiningOp();
- if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
+ if (!variableDef ||
+ !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
return emitOpError() << "requires first operand (" << variable
- << ") to be a Variable";
+ << ") to be a Variable or subscript";
Value value = getValue();
if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult emitc::SubscriptOp::verify() {
+ if (getIndices().size() != (size_t)getArray().getType().getRank()) {
+ return emitOpError() << "requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << getArray().getType().getRank() << ")";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..6e477a34fc4ba9 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -171,6 +171,9 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);
+ /// Return the textual representation of a subscript operation.
+ std::string getSubscriptName(emitc::SubscriptOp op);
+
/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);
@@ -340,8 +343,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
- auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
- OpResult result = variableOp->getResult(0);
+ OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
if (failed(emitter.emitVariableAssignment(result)))
return failure();
@@ -349,6 +351,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::SubscriptOp subscriptOp) {
+ // Add name to cache so that `hasValueInScope` works.
+ emitter.getOrCreateName(subscriptOp.getResult());
+ return success();
+}
+
static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
@@ -1067,12 +1076,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
labelInScopeCount.push(0);
}
+std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << getOrCreateName(op.getArray());
+ for (auto index : op.getIndices()) {
+ ss << "[" << getOrCreateName(index) << "]";
+ }
+ return out;
+}
+
/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
- if (!valueMapper.count(val))
- valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+ if (!valueMapper.count(val)) {
+ if (auto subscript =
+ dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
+ valueMapper.insert(val, getSubscriptName(subscript));
+ } else {
+ valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+ }
+ }
return *valueMapper.begin(val);
}
@@ -1312,6 +1337,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
+ if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+ return success();
if (hasValueInScope(result)) {
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
@@ -1387,8 +1414,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
- emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
- emitc::VariableOp, emitc::VerbatimOp>(
+ emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
+ emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1401,7 +1428,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();
- if (isa<emitc::LiteralOp>(op))
+ if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
return success();
if (getEmittedExpression() ||
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 58b3a11ed93e15..cc718e190484a8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
%0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
return
}
+
+// -----
+
+func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+ %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
+ return
+}
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
new file mode 100644
index 00000000000000..0f9bc515b48dc2
--- /dev/null
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
+ %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
+ emitc.assign %0 : f32 to %1 : f32
+ return
+}
+// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
``````````
</details>
https://github.com/llvm/llvm-project/pull/84783
More information about the Mlir-commits
mailing list