[Mlir-commits] [mlir] 01a31ce - [MLIR] EmitC: Add subscript operator (#84783)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 15 03:08:39 PDT 2024


Author: Matthias Gehre
Date: 2024-03-15T11:08:34+01:00
New Revision: 01a31cee561efe90fbd1d33fa89f403dd8ff9012

URL: https://github.com/llvm/llvm-project/commit/01a31cee561efe90fbd1d33fa89f403dd8ff9012
DIFF: https://github.com/llvm/llvm-project/commit/01a31cee561efe90fbd1d33fa89f403dd8ff9012.diff

LOG: [MLIR] EmitC: Add subscript operator (#84783)

Introduces a SubscriptOp that allows writing IR like
```
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
  emitc.assign %0 : f32 to %1 : f32
  return
}
```
which gets translated into the C++ code (writing `%argN` as `vN`)
```
v1[v2][v3] = v0[v2][v3];
```

To make this happen, this
- adds the SubscriptOp
- allows the result of a subscript op as the assigned-to (first) operand of emitc.assign, i.e. the lhs of the emitted assignment
- updates the emitter to print SubscriptOps

The emitter prints emitc.subscript in a delayed fashion to allow it
to be used as an lvalue.
That is, while processing
```
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
```
it will not emit any text, but will record in the `valueMapper` that the name
for `%0` is `v0[v2][v3]`; see `CppEmitter::getSubscriptName`. Only when
that result is used (here, by `emitc.assign`) is that name inserted
into the text.

Added: 
    mlir/test/Target/Cpp/subscript.mlir

Modified: 
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Target/Cpp/TranslateToCpp.cpp
    mlir/test/Dialect/EmitC/invalid_ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ec842f76628c08..78bfd561171f50 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,4 +1155,36 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
+def EmitC_SubscriptOp : EmitC_Op<"subscript",
+  [TypesMatchWith<"result type matches element type of 'array'",
+                  "array", "result",
+                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
+  let summary = "Array subscript operation";
+  let description = [{
+    With the `subscript` operation the subscript operator `[]` can be applied
+    to variables or arguments of array type.
+
+    Example:
+
+    ```mlir
+    %i = index.constant 1
+    %j = index.constant 7
+    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+    ```
+  }];
+  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
+                       Variadic<IntegerIndexOrOpaqueType>:$indices);
+  let results = (outs AnyType:$result);
+
+  let builders = [
+    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    }]>
+  ];
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+}
+
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 9426bbbe2370f0..e401a83bcb42e6 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
 LogicalResult emitc::AssignOp::verify() {
   Value variable = getVar();
   Operation *variableDef = variable.getDefiningOp();
-  if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
+  if (!variableDef ||
+      !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
     return emitOpError() << "requires first operand (" << variable
-                         << ") to be a Variable";
+                         << ") to be a Variable or subscript";
 
   Value value = getValue();
   if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubscriptOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult emitc::SubscriptOp::verify() {
+  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
+    return emitOpError() << "requires number of indices ("
+                         << getIndices().size()
+                         << ") to match the rank of the array type ("
+                         << getArray().getType().getRank() << ")";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index da2e6e5891da63..bc49d7cd67126e 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -174,6 +174,9 @@ struct CppEmitter {
   /// Return the existing or a new name for a Value.
   StringRef getOrCreateName(Value val);
 
+  // Returns the textual representation of a subscript operation.
+  std::string getSubscriptName(emitc::SubscriptOp op);
+
   /// Return the existing or a new label of a Block.
   StringRef getOrCreateName(Block &block);
 
@@ -343,8 +346,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::AssignOp assignOp) {
-  auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
-  OpResult result = variableOp->getResult(0);
+  OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
 
   if (failed(emitter.emitVariableAssignment(result)))
     return failure();
@@ -352,6 +354,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return emitter.emitOperand(assignOp.getValue());
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::SubscriptOp subscriptOp) {
+  // Add name to cache so that `hasValueInScope` works.
+  emitter.getOrCreateName(subscriptOp.getResult());
+  return success();
+}
+
 static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                           Operation *operation,
                                           StringRef binaryOperator) {
@@ -1093,12 +1102,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
   labelInScopeCount.push(0);
 }
 
+std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
+  std::string out;
+  llvm::raw_string_ostream ss(out);
+  ss << getOrCreateName(op.getArray());
+  for (auto index : op.getIndices()) {
+    ss << "[" << getOrCreateName(index) << "]";
+  }
+  return out;
+}
+
 /// Return the existing or a new name for a Value.
 StringRef CppEmitter::getOrCreateName(Value val) {
   if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
     return literal.getValue();
-  if (!valueMapper.count(val))
-    valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+  if (!valueMapper.count(val)) {
+    if (auto subscript =
+            dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
+      valueMapper.insert(val, getSubscriptName(subscript));
+    } else {
+      valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+    }
+  }
   return *valueMapper.begin(val);
 }
 
@@ -1338,6 +1363,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
 
 LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                   bool trailingSemicolon) {
+  if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
+    return success();
   if (hasValueInScope(result)) {
     return result.getDefiningOp()->emitError(
         "result variable for the operation already declared");
@@ -1413,7 +1440,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
                 emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
                 emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
-                emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
                 emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
                 emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
@@ -1428,7 +1455,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   if (failed(status))
     return failure();
 
-  if (isa<emitc::LiteralOp>(op))
+  if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
     return success();
 
   if (getEmittedExpression() ||

diff  --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 58b3a11ed93e15..6294c853d99931 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
 // -----
 
 func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
-  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
+  // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
   emitc.assign %arg1 : f32 to %arg2 : f32
   return
 }
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
   %0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
   return
 }
+
+// -----
+
+func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+  return
+}

diff  --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
new file mode 100644
index 00000000000000..a6c82df9111a79
--- /dev/null
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+
+emitc.func @func1(%arg0 : f32) {
+  emitc.return
+}
+
+emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
+                     %k: i8) {
+  %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
+  %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+
+  emitc.call @func1 (%0) : (f32) -> ()
+  emitc.call_opaque "func2" (%1) : (f32) -> ()
+  emitc.call_opaque "func3" (%0, %1) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+  emitc.return
+}
+// CHECK: void call_arg(float [[ARR1:[^ ]*]][4][8], int32_t [[I:[^ ]*]],
+// CHECK-SAME:          int16_t [[J:[^ ]*]], int8_t [[K:[^ ]*]])
+// CHECK-NEXT: func1([[ARR1]][[[I]]][[[J]]]);
+// CHECK-NEXT: func2([[ARR1]][[[J]]][[[K]]]);
+// CHECK-NEXT: func3([[ARR1]][[[J]]][[[K]]], [[ARR1]][[[I]]][[[J]]]);


        


More information about the Mlir-commits mailing list