[Mlir-commits] [mlir] [MLIR] EmitC: Add subscript operator (PR #84783)

Matthias Gehre llvmlistbot at llvm.org
Thu Mar 14 15:10:17 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/84783

>From 731eb8e8694d6f68c2fdde2f2fe5686cc481a0bc Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 11 Mar 2024 16:52:57 +0100
Subject: [PATCH 1/2] [MLIR] EmitC: Add subscript operator

Introduces a SubscriptOp that allows writing IR like
```
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
  emitc.assign %0 : f32 to %1 : f32
  return
}
```
which gets translated into the C++ code
```
v1[v2][v3] = v0[v2][v3];
```

To make this happen, this
- adds the SubscriptOp
- allows the subscript op as the destination (`var` operand) of emitc.assign
- updates the emitter to print SubscriptOps

The emitter prints emitc.subscript in a delayed fashion so that
it can be used as an lvalue.
I.e. while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
```
it will not emit any text, but record in the `valueMapper`
that the name for `%0` is `v0[v2][v3]`, see `CppEmitter::getSubscriptName`.
Only when that result is then used (here in `emitc.assign`),
that name is inserted into the text.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 32 +++++++++++++++++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 19 ++++++++--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      | 39 +++++++++++++++++----
 mlir/test/Dialect/EmitC/invalid_ops.mlir    |  8 +++++
 mlir/test/Target/Cpp/subscript.mlir         | 12 +++++++
 5 files changed, 102 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/subscript.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ec842f76628c08..fc00fe0c6dfe1e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,4 +1155,36 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_SubscriptOp : EmitC_Op<"subscript",
+  [TypesMatchWith<"result type matches element type of 'array'",
+                  "array", "result",
+                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
+  let summary = "Array subscript operation";
+  let description = [{
+    With the `subscript` operation the subscript operator `[]` can be applied
+    to variables or arguments of array type.
+
+    Example:
+
+    ```mlir
+    %i = index.constant 1
+    %j = index.constant 7
+    %0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
+    ```
+  }];
+  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyType:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    }]>
+  ];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
+}
+
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 9426bbbe2370f0..e401a83bcb42e6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
 LogicalResult emitc::AssignOp::verify() {
   Value variable = getVar();
   Operation *variableDef = variable.getDefiningOp();
-  if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
+  if (!variableDef ||
+      !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
     return emitOpError() << "requires first operand (" << variable
-                         << ") to be a Variable";
+                         << ") to be a Variable or subscript";
 
   Value value = getValue();
   if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult emitc::SubscriptOp::verify() {
+  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
+    return emitOpError() << "requires number of indices ("
+                         << getIndices().size()
+                         << ") to match the rank of the array type ("
+                         << getArray().getType().getRank() << ")";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 7cbb1e9265e174..f19ab031096e2f 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -172,6 +172,9 @@ struct CppEmitter {
   /// Return the existing or a new name for a Value.
   StringRef getOrCreateName(Value val);
 
+  // Returns the textual representation of a subscript operation.
+  std::string getSubscriptName(emitc::SubscriptOp op);
+
   /// Return the existing or a new label of a Block.
   StringRef getOrCreateName(Block &block);
 
@@ -341,8 +344,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::AssignOp assignOp) {
-  auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
-  OpResult result = variableOp->getResult(0);
+  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
 
   if (failed(emitter.emitVariableAssignment(result)))
     return failure();
@@ -350,6 +352,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return emitter.emitOperand(assignOp.getValue());
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::SubscriptOp subscriptOp) {
+  // Add name to cache so that `hasValueInScope` works.
+  emitter.getOrCreateName(subscriptOp.getResult());
+  return success();
+}
+
 static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                           Operation *operation,
                                           StringRef binaryOperator) {
@@ -1091,12 +1100,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
   labelInScopeCount.push(0);
 }
 
+std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
+  std::string out;
+  llvm::raw_string_ostream ss(out);
+  ss << getOrCreateName(op.getArray());
+  for (auto index : op.getIndices()) {
+    ss << "[" << getOrCreateName(index) << "]";
+  }
+  return out;
+}
+
 /// Return the existing or a new name for a Value.
 StringRef CppEmitter::getOrCreateName(Value val) {
   if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
     return literal.getValue();
-  if (!valueMapper.count(val))
-    valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+  if (!valueMapper.count(val)) {
+    if (auto subscript =
+            dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
+      valueMapper.insert(val, getSubscriptName(subscript));
+    } else {
+      valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+    }
+  }
   return *valueMapper.begin(val);
 }
 
@@ -1336,6 +1361,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
 
 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                   bool trailingSemicolon) {
+  if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+    return success();
   if (hasValueInScope(result)) {
     return result.getDefiningOp()->emitError(
         "result variable for the operation already declared");
@@ -1411,7 +1438,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
                 emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
                 emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
-                emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
                 emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
                 emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
@@ -1426,7 +1453,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   if (failed(status))
     return failure();
 
-  if (isa<emitc::LiteralOp>(op))
+  if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
     return success();
 
   if (getEmittedExpression() ||
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 58b3a11ed93e15..cc718e190484a8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
   %0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
   return
 }
+
+// -----
+
+func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
+  return
+}
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
new file mode 100644
index 00000000000000..0f9bc515b48dc2
--- /dev/null
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

>From 8d85dd63bda7331357304153928785eb46649062 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 14 Mar 2024 22:51:49 +0100
Subject: [PATCH 2/2] review comments

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td |  6 +++---
 mlir/test/Dialect/EmitC/invalid_ops.mlir    |  4 ++--
 mlir/test/Target/Cpp/subscript.mlir         | 24 +++++++++++++++++++--
 3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index fc00fe0c6dfe1e..78bfd561171f50 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1169,11 +1169,11 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
     ```mlir
     %i = index.constant 1
     %j = index.constant 7
-    %0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
+    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
     ```
   }];
   let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
-                       Variadic<Index>:$indices);
+                       Variadic<IntegerIndexOrOpaqueType>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
@@ -1183,7 +1183,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
   ];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
+  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
 }
 
 
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index cc718e190484a8..6294c853d99931 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
 // -----
 
 func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
-  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
+  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
   emitc.assign %arg1 : f32 to %arg2 : f32
   return
 }
@@ -392,6 +392,6 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
 
 func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
   // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
-  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
+  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
   return
 }
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index 0f9bc515b48dc2..a6c82df9111a79 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -2,11 +2,31 @@
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
 
 func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
-  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
-  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
   emitc.assign %0 : f32 to %1 : f32
   return
 }
 // CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
 // CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
 // CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+
+emitc.func @func1(%arg0 : f32) {
+  emitc.return
+}
+
+emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
+                     %k: i8) {
+  %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
+  %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+
+  emitc.call @func1 (%0) : (f32) -> ()
+  emitc.call_opaque "func2" (%1) : (f32) -> ()
+  emitc.call_opaque "func3" (%0, %1) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+  emitc.return
+}
+// CHECK: void call_arg(float [[ARR1:[^ ]*]][4][8], int32_t [[I:[^ ]*]],
+// CHECK-SAME:          int16_t [[J:[^ ]*]], int8_t [[K:[^ ]*]])
+// CHECK-NEXT: func1([[ARR1]][[[I]]][[[J]]]);
+// CHECK-NEXT: func2([[ARR1]][[[J]]][[[K]]]);
+// CHECK-NEXT: func3([[ARR1]][[[J]]][[[K]]], [[ARR1]][[[I]]][[[J]]]);



More information about the Mlir-commits mailing list