[Mlir-commits] [mlir] 01a31ce - [MLIR] EmitC: Add subscript operator (#84783)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 15 03:08:39 PDT 2024
Author: Matthias Gehre
Date: 2024-03-15T11:08:34+01:00
New Revision: 01a31cee561efe90fbd1d33fa89f403dd8ff9012
URL: https://github.com/llvm/llvm-project/commit/01a31cee561efe90fbd1d33fa89f403dd8ff9012
DIFF: https://github.com/llvm/llvm-project/commit/01a31cee561efe90fbd1d33fa89f403dd8ff9012.diff
LOG: [MLIR] EmitC: Add subscript operator (#84783)
Introduces a SubscriptOp that allows writing IR like
```
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
emitc.assign %0 : f32 to %1 : f32
return
}
```
which gets translated into the C++ code
```
v1[v2][v3] = v0[v2][v3];
```
To make this happen, this
- adds the SubscriptOp
- allows the subscript op as the assignment target (the first operand, i.e. the lhs) of emitc.assign
- updates the emitter to print SubscriptOps
The emitter prints emitc.subscript in a delayed fashion to allow it
to be used as an lvalue.
That is, while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
```
it will not emit any text, but will record in the `valueMapper` that the name
for `%0` is `v0[v2][v3]`, see `CppEmitter::getSubscriptName`. Only when
that result is later used (here in `emitc.assign`) is that name inserted
into the text.
Added:
mlir/test/Target/Cpp/subscript.mlir
Modified:
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
mlir/lib/Target/Cpp/TranslateToCpp.cpp
mlir/test/Dialect/EmitC/invalid_ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ec842f76628c08..78bfd561171f50 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,4 +1155,36 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}
+def EmitC_SubscriptOp : EmitC_Op<"subscript",
+ [TypesMatchWith<"result type matches element type of 'array'",
+ "array", "result",
+ "::llvm::cast<ArrayType>($_self).getElementType()">]> {
+ let summary = "Array subscript operation";
+ let description = [{
+ With the `subscript` operation the subscript operator `[]` can be applied
+ to variables or arguments of array type.
+
+ Example:
+
+ ```mlir
+ %i = index.constant 1
+ %j = index.constant 7
+ %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+ ```
+ }];
+ let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
+ Variadic<IntegerIndexOrOpaqueType>:$indices);
+ let results = (outs AnyType:$result);
+
+ let builders = [
+ OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
+ build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+ }]>
+ ];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+}
+
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 9426bbbe2370f0..e401a83bcb42e6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
LogicalResult emitc::AssignOp::verify() {
Value variable = getVar();
Operation *variableDef = variable.getDefiningOp();
- if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
+ if (!variableDef ||
+ !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
return emitOpError() << "requires first operand (" << variable
- << ") to be a Variable";
+ << ") to be a Variable or subscript";
Value value = getValue();
if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult emitc::SubscriptOp::verify() {
+ if (getIndices().size() != (size_t)getArray().getType().getRank()) {
+ return emitOpError() << "requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << getArray().getType().getRank() << ")";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index da2e6e5891da63..bc49d7cd67126e 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -174,6 +174,9 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);
+ // Returns the textual representation of a subscript operation.
+ std::string getSubscriptName(emitc::SubscriptOp op);
+
/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);
@@ -343,8 +346,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
- auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
- OpResult result = variableOp->getResult(0);
+ OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
if (failed(emitter.emitVariableAssignment(result)))
return failure();
@@ -352,6 +354,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::SubscriptOp subscriptOp) {
+ // Add name to cache so that `hasValueInScope` works.
+ emitter.getOrCreateName(subscriptOp.getResult());
+ return success();
+}
+
static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
@@ -1093,12 +1102,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
labelInScopeCount.push(0);
}
+std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
+ std::string out;
+ llvm::raw_string_ostream ss(out);
+ ss << getOrCreateName(op.getArray());
+ for (auto index : op.getIndices()) {
+ ss << "[" << getOrCreateName(index) << "]";
+ }
+ return out;
+}
+
/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
- if (!valueMapper.count(val))
- valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+ if (!valueMapper.count(val)) {
+ if (auto subscript =
+ dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
+ valueMapper.insert(val, getSubscriptName(subscript));
+ } else {
+ valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+ }
+ }
return *valueMapper.begin(val);
}
@@ -1338,6 +1363,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
+ if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+ return success();
if (hasValueInScope(result)) {
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
@@ -1413,7 +1440,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+ emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
@@ -1428,7 +1455,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();
- if (isa<emitc::LiteralOp>(op))
+ if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
return success();
if (getEmittedExpression() ||
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 58b3a11ed93e15..6294c853d99931 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
// -----
func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
- // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
+ // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
emitc.assign %arg1 : f32 to %arg2 : f32
return
}
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
%0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
return
}
+
+// -----
+
+func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+ %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+ return
+}
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
new file mode 100644
index 00000000000000..a6c82df9111a79
--- /dev/null
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
+ %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+ emitc.assign %0 : f32 to %1 : f32
+ return
+}
+// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+
+emitc.func @func1(%arg0 : f32) {
+ emitc.return
+}
+
+emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
+ %k: i8) {
+ %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
+ %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+
+ emitc.call @func1 (%0) : (f32) -> ()
+ emitc.call_opaque "func2" (%1) : (f32) -> ()
+ emitc.call_opaque "func3" (%0, %1) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+ emitc.return
+}
+// CHECK: void call_arg(float [[ARR1:[^ ]*]][4][8], int32_t [[I:[^ ]*]],
+// CHECK-SAME: int16_t [[J:[^ ]*]], int8_t [[K:[^ ]*]])
+// CHECK-NEXT: func1([[ARR1]][[[I]]][[[J]]]);
+// CHECK-NEXT: func2([[ARR1]][[[J]]][[[K]]]);
+// CHECK-NEXT: func3([[ARR1]][[[J]]][[[K]]], [[ARR1]][[[I]]][[[J]]]);
More information about the Mlir-commits
mailing list