[Mlir-commits] [mlir] [MLIR] EmitC: Add subscript operator (PR #84783)

Matthias Gehre llvmlistbot at llvm.org
Mon Mar 11 09:20:00 PDT 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/84783

Introduces a SubscriptOp that makes it possible to write IR like
```
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
  emitc.assign %0 : f32 to %1 : f32
  return
}
```
which gets translated into the C++ code
```
v2[v3][v4] = v1[v3][v4];
```

To make this happen, this
- adds the SubscriptOp
- allows the result of a subscript op as the assignment target (the `var` operand) of emitc.assign
- updates the emitter to print SubscriptOps

The emitter prints emitc.subscript in a delayed fashion so that it can be used as an lvalue.
I.e. while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
```
it will not emit any text, but record in the `valueMapper` that the name for `%0` is `v1[v3][v4]`, see `CppEmitter::getSubscriptName`. Only when that result is then used (here in `emitc.assign`) is that name inserted into the text.

>From 3bc50bf76d86aebc13a5e9d77cd4ad726b5772eb Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 11 Mar 2024 16:52:57 +0100
Subject: [PATCH] [MLIR] EmitC: Add subscript operator

Introduces a SubscriptOp that makes it possible to write IR like
```
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
  emitc.assign %0 : f32 to %1 : f32
  return
}
```
which gets translated into the C++ code
```
v2[v3][v4] = v1[v3][v4];
```

To make this happen, this
- adds the SubscriptOp
- allows the result of a subscript op as the assignment target
  (the `var` operand) of emitc.assign
- updates the emitter to print SubscriptOps

The emitter prints emitc.subscript in a delayed fashion so that
it can be used as an lvalue.
I.e. while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
```
it will not emit any text, but record in the `valueMapper`
that the name for `%0` is `v1[v3][v4]`, see `CppEmitter::getSubscriptName`.
Only when that result is then used (here in `emitc.assign`)
is that name inserted into the text.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 32 ++++++++++++++++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 19 +++++++++-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      | 41 +++++++++++++++++----
 mlir/test/Dialect/EmitC/invalid_ops.mlir    |  8 ++++
 mlir/test/Target/Cpp/subscript.mlir         | 12 ++++++
 5 files changed, 103 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/subscript.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..bcdd001528c46d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1125,4 +1125,36 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_SubscriptOp : EmitC_Op<"subscript",
+  [TypesMatchWith<"result type matches element type of 'array'",
+                  "array", "result",
+                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
+  let summary = "Array subscript operation";
+  let description = [{
+    With the `subscript` operation the subscript operator `[]` can be applied
+    to variables or arguments of array type.
+
+    Example:
+
+    ```mlir
+    %i = index.constant 1
+    %j = index.constant 7
+    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>
+    ```
+  }];
+  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyType:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    }]>
+  ];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
+}
+
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 9426bbbe2370f0..e401a83bcb42e6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
 LogicalResult emitc::AssignOp::verify() {
   Value variable = getVar();
   Operation *variableDef = variable.getDefiningOp();
-  if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
+  if (!variableDef ||
+      !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
     return emitOpError() << "requires first operand (" << variable
-                         << ") to be a Variable";
+                         << ") to be a Variable or subscript";
 
   Value value = getValue();
   if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult emitc::SubscriptOp::verify() {
+  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
+    return emitOpError() << "requires number of indices ("
+                         << getIndices().size()
+                         << ") to match the rank of the array type ("
+                         << getArray().getType().getRank() << ")";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..6e477a34fc4ba9 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -171,6 +171,9 @@ struct CppEmitter {
   /// Return the existing or a new name for a Value.
   StringRef getOrCreateName(Value val);
 
+  /// Return the textual representation of a subscript operation.
+  std::string getSubscriptName(emitc::SubscriptOp op);
+
   /// Return the existing or a new label of a Block.
   StringRef getOrCreateName(Block &block);
 
@@ -340,8 +343,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::AssignOp assignOp) {
-  auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
-  OpResult result = variableOp->getResult(0);
+  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
 
   if (failed(emitter.emitVariableAssignment(result)))
     return failure();
@@ -349,6 +351,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return emitter.emitOperand(assignOp.getValue());
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::SubscriptOp subscriptOp) {
+  // Add name to cache so that `hasValueInScope` works.
+  emitter.getOrCreateName(subscriptOp.getResult());
+  return success();
+}
+
 static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                           Operation *operation,
                                           StringRef binaryOperator) {
@@ -1067,12 +1076,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
   labelInScopeCount.push(0);
 }
 
+std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
+  std::string out;
+  llvm::raw_string_ostream ss(out);
+  ss << getOrCreateName(op.getArray());
+  for (auto index : op.getIndices()) {
+    ss << "[" << getOrCreateName(index) << "]";
+  }
+  return out;
+}
+
 /// Return the existing or a new name for a Value.
 StringRef CppEmitter::getOrCreateName(Value val) {
   if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
     return literal.getValue();
-  if (!valueMapper.count(val))
-    valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+  if (!valueMapper.count(val)) {
+    if (auto subscript =
+            dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
+      valueMapper.insert(val, getSubscriptName(subscript));
+    } else {
+      valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+    }
+  }
   return *valueMapper.begin(val);
 }
 
@@ -1312,6 +1337,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
 
 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                   bool trailingSemicolon) {
+  if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+    return success();
   if (hasValueInScope(result)) {
     return result.getDefiningOp()->emitError(
         "result variable for the operation already declared");
@@ -1387,8 +1414,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
                 emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
                 emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
-                emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
-                emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
+                emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1401,7 +1428,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   if (failed(status))
     return failure();
 
-  if (isa<emitc::LiteralOp>(op))
+  if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
     return success();
 
   if (getEmittedExpression() ||
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 58b3a11ed93e15..cc718e190484a8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
   %0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
   return
 }
+
+// -----
+
+func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
+  return
+}
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
new file mode 100644
index 00000000000000..0f9bc515b48dc2
--- /dev/null
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];



More information about the Mlir-commits mailing list