[Mlir-commits] [mlir] [mlir][EmitC] Add support for pointer and opaque types to subscript op (PR #86266)

Simon Camphausen llvmlistbot at llvm.org
Tue Apr 2 05:34:46 PDT 2024


https://github.com/simon-camp updated https://github.com/llvm/llvm-project/pull/86266

>From d7e255ab8e580661c351b4ad3b3fec45c37a15c1 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 10:28:28 +0000
Subject: [PATCH 1/7] [mlir][EmitC] Add support for pointer and opaque types to
 subscript op.

For pointer types, the indices are restricted to a single integer-like operand.
For opaque types, no further restrictions are imposed on the number or types of
the index operands.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |  6 ++
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 30 ++++++----
 .../MemRefToEmitC/MemRefToEmitC.cpp           |  6 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 56 +++++++++++++++++--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  2 +-
 .../MemRefToEmitC/memref-to-emitc.mlir        |  4 +-
 mlir/test/Dialect/EmitC/invalid_ops.mlir      | 44 ++++++++++++++-
 mlir/test/Dialect/EmitC/ops.mlir              |  7 +++
 mlir/test/Target/Cpp/subscript.mlir           | 32 +++++++++--
 9 files changed, 157 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bcb4e6cb1..d2f20b642b26b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
 namespace mlir {
 namespace emitc {
 void buildTerminatedBody(OpBuilder &builder, Location loc);
+
 /// Determines whether \p type is a valid integer type in EmitC.
 bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerLikeType(Type type);
+
 /// Determines whether \p type is a valid floating-point type in EmitC.
 bool isSupportedFloatType(mlir::Type type);
 } // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..539a4f3e9805e1 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
   let hasCustomAssemblyFormat = 1;
 }
 
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
-  [TypesMatchWith<"result type matches element type of 'array'",
-                  "array", "result",
-                  "::llvm::cast<ArrayType>($_self).getElementType()">]> {
-  let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+  let summary = "Subscript operation";
   let description = [{
     With the `subscript` operation the subscript operator `[]` can be applied
-    to variables or arguments of array type.
+    to variables or arguments of array, pointer and opaque type.
 
     Example:
 
     ```mlir
     %i = index.constant 1
     %j = index.constant 7
-    %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+    %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+    %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
     ```
   }];
-  let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
-                       Variadic<IntegerIndexOrOpaqueType>:$indices);
+  let arguments = (ins Arg<AnyTypeOf<[
+      EmitC_ArrayType,
+      EmitC_OpaqueType,
+      EmitC_PointerType]>,
+    "the reference to load from">:$ref,
+    Variadic<AnyType>:$indices);
   let results = (outs AnyType:$result);
 
   let builders = [
-    OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
-      build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+    OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+      build($_builder, $_state, array.getType().getElementType(), array, indices);
+    }]>,
+    OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+      build($_builder, $_state, pointer.getType().getPointee(), pointer,
+            ValueRange{index});
     }]>
   ];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+  let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
 
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..3a2405a6195437 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
     }
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
 
     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
     auto var =
@@ -83,7 +84,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), operands.getMemref(), operands.getIndices());
+        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+        operands.getIndices());
     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
                                                  operands.getValue());
     return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f364573552fe97 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
   return false;
 }
 
+bool mlir::emitc::isIntegerLikeType(Type type) {
+  return isSupportedIntegerType(type) ||
+         llvm::isa<IndexType, emitc::OpaqueType>(type);
+}
+
 bool mlir::emitc::isSupportedFloatType(Type type) {
   if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
     switch (floatType.getWidth()) {
@@ -781,11 +786,52 @@ LogicalResult emitc::YieldOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult emitc::SubscriptOp::verify() {
-  if (getIndices().size() != (size_t)getArray().getType().getRank()) {
-    return emitOpError() << "requires number of indices ("
-                         << getIndices().size()
-                         << ") to match the rank of the array type ("
-                         << getArray().getType().getRank() << ")";
+  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != (size_t)arrayType.getRank()) {
+      return emitOpError() << "requires number of indices ("
+                           << getIndices().size()
+                           << ") to match the rank of the array type ("
+                           << arrayType.getRank() << ")";
+    }
+    // Check types of index operands.
+    for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+      Type type = getIndices()[i].getType();
+      if (!isIntegerLikeType(type)) {
+        return emitOpError() << "requires index operand " << i
+                             << " to be integer-like, but got " << type;
+      }
+    }
+    // Check element type.
+    Type elementType = arrayType.getElementType();
+    if (elementType != getType()) {
+      return emitOpError() << "requires element type (" << elementType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else if (auto pointerType =
+                 llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
+    // Check arity of indices.
+    if (getIndices().size() != 1) {
+      return emitOpError() << "requires one index operand, but got "
+                           << getIndices().size();
+    }
+    // Check types of index operand.
+    Type type = getIndices()[0].getType();
+    if (!isIntegerLikeType(type)) {
+      return emitOpError()
+             << "requires index operand to be integer-like, but got " << type;
+    }
+    // Check pointee type.
+    Type pointeeType = pointerType.getPointee();
+    if (pointeeType != getType()) {
+      return emitOpError() << "requires pointee type (" << pointeeType
+                           << ") and result type (" << getType()
+                           << ") to match";
+    }
+  } else {
+    // The reference has opaque type, so we can't assume anything about arity or
+    // types of index operands.
   }
   return success();
 }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..8fd04b7d1a51e0 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   std::string out;
   llvm::raw_string_ostream ss(out);
-  ss << getOrCreateName(op.getArray());
+  ss << getOrCreateName(op.getRef());
   for (auto index : op.getIndices()) {
     ss << "[" << getOrCreateName(index) << "]";
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..7aa2ba88843a2a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
   memref.store %v, %0[%i, %j] : memref<4x8xf32>
   return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
 
-  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
   // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
   %1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..321e4c01110e82 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
 
 // -----
 
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
   // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
-  %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+  %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+  %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+  %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+  // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+  return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+  // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..ace3670426afa5 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
   return
 }
 
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+  %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+  %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  return
+}
+
 emitc.verbatim "#ifdef __cplusplus"
 emitc.verbatim "extern \"C\" {"
 emitc.verbatim "#endif  // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df9111a79..0b388953c80d37 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
 // RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
 
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
-  %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
-  %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+  %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
   emitc.assign %0 : f32 to %1 : f32
   return
 }
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
 // CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
 // CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
 
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+  emitc.assign %0 : f32 to %1 : f32
+  return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME:            size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+  %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+  emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+  return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME:            char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
 emitc.func @func1(%arg0 : f32) {
   emitc.return
 }
 
 emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
                      %k: i8) {
-  %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
-  %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+  %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+  %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
 
   emitc.call @func1 (%0) : (f32) -> ()
   emitc.call_opaque "func2" (%1) : (f32) -> ()

>From fa15d981e78266905b605b15d37fb2f4485d3109 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 12:52:04 +0000
Subject: [PATCH 2/7] Review feedback

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h  |  2 +-
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td |  4 ++--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 14 +++++++-------
 mlir/lib/Target/Cpp/TranslateToCpp.cpp      |  2 +-
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index d2f20b642b26b2..c03915667db653 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -36,7 +36,7 @@ bool isSupportedIntegerType(mlir::Type type);
 
 /// Determines whether \p type is integer like, i.e. it's a supported integer,
 /// an index or opaque type.
-bool isIntegerLikeType(Type type);
+bool isIntegerIndexOrOpaqueType(Type type);
 
 /// Determines whether \p type is a valid floating-point type in EmitC.
 bool isSupportedFloatType(mlir::Type type);
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 539a4f3e9805e1..090dae8a6aef3d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1174,7 +1174,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
       EmitC_ArrayType,
       EmitC_OpaqueType,
       EmitC_PointerType]>,
-    "the reference to load from">:$ref,
+    "the value to subscript">:$value,
     Variadic<AnyType>:$indices);
   let results = (outs AnyType:$result);
 
@@ -1189,7 +1189,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
   ];
 
   let hasVerifier = 1;
-  let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
+  let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f364573552fe97..5272e81dfa4d75 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,7 +70,7 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
   return false;
 }
 
-bool mlir::emitc::isIntegerLikeType(Type type) {
+bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
   return isSupportedIntegerType(type) ||
          llvm::isa<IndexType, emitc::OpaqueType>(type);
 }
@@ -786,8 +786,8 @@ LogicalResult emitc::YieldOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult emitc::SubscriptOp::verify() {
-  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
-    // Check arity of indices.
+  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
+    // Check number of indices.
     if (getIndices().size() != (size_t)arrayType.getRank()) {
       return emitOpError() << "requires number of indices ("
                            << getIndices().size()
@@ -797,7 +797,7 @@ LogicalResult emitc::SubscriptOp::verify() {
     // Check types of index operands.
     for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
       Type type = getIndices()[i].getType();
-      if (!isIntegerLikeType(type)) {
+      if (!isIntegerIndexOrOpaqueType(type)) {
         return emitOpError() << "requires index operand " << i
                              << " to be integer-like, but got " << type;
       }
@@ -810,15 +810,15 @@ LogicalResult emitc::SubscriptOp::verify() {
                            << ") to match";
     }
   } else if (auto pointerType =
-                 llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
-    // Check arity of indices.
+                 llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
+    // Check number of indices.
     if (getIndices().size() != 1) {
       return emitOpError() << "requires one index operand, but got "
                            << getIndices().size();
     }
     // Check types of index operand.
     Type type = getIndices()[0].getType();
-    if (!isIntegerLikeType(type)) {
+    if (!isIntegerIndexOrOpaqueType(type)) {
       return emitOpError()
              << "requires index operand to be integer-like, but got " << type;
     }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8fd04b7d1a51e0..d071efdfac0fd3 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
 std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
   std::string out;
   llvm::raw_string_ostream ss(out);
-  ss << getOrCreateName(op.getRef());
+  ss << getOrCreateName(op.getValue());
   for (auto index : op.getIndices()) {
     ss << "[" << getOrCreateName(index) << "]";
   }

>From f42375e21abd95d5d365340059d53692107cfb77 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 13:05:45 +0000
Subject: [PATCH 3/7] Reduce control flow nesting

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5272e81dfa4d75..7fbf602a47a7c3 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -71,8 +71,8 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
 }
 
 bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
-  return isSupportedIntegerType(type) ||
-         llvm::isa<IndexType, emitc::OpaqueType>(type);
+  return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
+         isSupportedIntegerType(type);
 }
 
 bool mlir::emitc::isSupportedFloatType(Type type) {
@@ -786,6 +786,7 @@ LogicalResult emitc::YieldOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult emitc::SubscriptOp::verify() {
+  // Checks for array operand.
   if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
     // Check number of indices.
     if (getIndices().size() != (size_t)arrayType.getRank()) {
@@ -809,8 +810,12 @@ LogicalResult emitc::SubscriptOp::verify() {
                            << ") and result type (" << getType()
                            << ") to match";
     }
-  } else if (auto pointerType =
-                 llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
+    return success();
+  }
+
+  // Checks for pointer operand.
+  if (auto pointerType =
+          llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
     // Check number of indices.
     if (getIndices().size() != 1) {
       return emitOpError() << "requires one index operand, but got "
@@ -829,10 +834,11 @@ LogicalResult emitc::SubscriptOp::verify() {
                            << ") and result type (" << getType()
                            << ") to match";
     }
-  } else {
-    // The reference has opaque type, so we can't assume anything about arity or
-    // types of index operands.
+    return success();
   }
+
+  // The operand has opaque type, so we can't assume anything about arity or
+  // types of index operands.
   return success();
 }
 

>From ce8d013ce510010c88a4948247700d574410e9ac Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 13:22:44 +0000
Subject: [PATCH 4/7] Make verification errors more descriptive

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp      | 22 ++++++++++++----------
 mlir/test/Dialect/EmitC/invalid_ops.mlir | 12 ++++++------
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 7fbf602a47a7c3..76421fd2a5d50b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -790,7 +790,7 @@ LogicalResult emitc::SubscriptOp::verify() {
   if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
     // Check number of indices.
     if (getIndices().size() != (size_t)arrayType.getRank()) {
-      return emitOpError() << "requires number of indices ("
+      return emitOpError() << "on array operand requires number of indices ("
                            << getIndices().size()
                            << ") to match the rank of the array type ("
                            << arrayType.getRank() << ")";
@@ -799,15 +799,15 @@ LogicalResult emitc::SubscriptOp::verify() {
     for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
       Type type = getIndices()[i].getType();
       if (!isIntegerIndexOrOpaqueType(type)) {
-        return emitOpError() << "requires index operand " << i
+        return emitOpError() << "on array operand requires index operand " << i
                              << " to be integer-like, but got " << type;
       }
     }
     // Check element type.
     Type elementType = arrayType.getElementType();
     if (elementType != getType()) {
-      return emitOpError() << "requires element type (" << elementType
-                           << ") and result type (" << getType()
+      return emitOpError() << "on array operand requires element type ("
+                           << elementType << ") and result type (" << getType()
                            << ") to match";
     }
     return success();
@@ -818,20 +818,22 @@ LogicalResult emitc::SubscriptOp::verify() {
           llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
     // Check number of indices.
     if (getIndices().size() != 1) {
-      return emitOpError() << "requires one index operand, but got "
-                           << getIndices().size();
+      return emitOpError()
+             << "on pointer operand requires one index operand, but got "
+             << getIndices().size();
     }
     // Check types of index operand.
     Type type = getIndices()[0].getType();
     if (!isIntegerIndexOrOpaqueType(type)) {
-      return emitOpError()
-             << "requires index operand to be integer-like, but got " << type;
+      return emitOpError() << "on pointer operand requires index operand to be "
+                              "integer-like, but got "
+                           << type;
     }
     // Check pointee type.
     Type pointeeType = pointerType.getPointee();
     if (pointeeType != getType()) {
-      return emitOpError() << "requires pointee type (" << pointeeType
-                           << ") and result type (" << getType()
+      return emitOpError() << "on pointer operand requires pointee type ("
+                           << pointeeType << ") and result type (" << getType()
                            << ") to match";
     }
     return success();
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 321e4c01110e82..77868a4b3c5466 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -391,7 +391,7 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
 // -----
 
 func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
-  // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
+  // expected-error @+1 {{'emitc.subscript' op on array operand requires number of indices (1) to match the rank of the array type (2)}}
   %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
   return
 }
@@ -399,7 +399,7 @@ func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %
 // -----
 
 func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
-  // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+  // expected-error @+1 {{'emitc.subscript' op on array operand requires index operand 1 to be integer-like, but got 'f32'}}
   %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
   return
 }
@@ -407,7 +407,7 @@ func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>
 // -----
 
 func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
-  // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+  // expected-error @+1 {{'emitc.subscript' op on array operand requires element type ('f32') and result type ('i32') to match}}
   %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
   return
 }
@@ -415,7 +415,7 @@ func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg
 // -----
 
 func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
-  // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+  // expected-error @+1 {{'emitc.subscript' op on pointer operand requires one index operand, but got 2}}
   %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
   return
 }
@@ -423,7 +423,7 @@ func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: in
 // -----
 
 func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
-  // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+  // expected-error @+1 {{'emitc.subscript' op on pointer operand requires index operand to be integer-like, but got 'f64'}}
   %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
   return
 }
@@ -431,7 +431,7 @@ func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2:
 // -----
 
 func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
-  // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+  // expected-error @+1 {{'emitc.subscript' op on pointer operand requires pointee type ('f32') and result type ('f64') to match}}
   %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
   return
 }

>From 1ee79030c195cebd8e1fe22b8906155a56cc264b Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 13:47:33 +0000
Subject: [PATCH 5/7] Name arguments consistently in lit tests

---
 mlir/test/Dialect/EmitC/invalid_ops.mlir | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 77868a4b3c5466..bbaab0d5b6f3a9 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -414,24 +414,24 @@ func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg
 
 // -----
 
-func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
   // expected-error @+1 {{'emitc.subscript' op on pointer operand requires one index operand, but got 2}}
-  %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+  %0 = emitc.subscript %arg0[%arg1, %arg1] : (!emitc.ptr<f32>, index, index) -> f32
   return
 }
 
 // -----
 
-func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: f64) {
   // expected-error @+1 {{'emitc.subscript' op on pointer operand requires index operand to be integer-like, but got 'f64'}}
-  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+  %0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, f64) -> f32
   return
 }
 
 // -----
 
-func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
   // expected-error @+1 {{'emitc.subscript' op on pointer operand requires pointee type ('f32') and result type ('f64') to match}}
-  %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
+  %0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, index) -> f64
   return
 }

>From 6262aacc1944e658c5e5d1deaac5e96e7ac42f12 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Fri, 22 Mar 2024 14:15:08 +0000
Subject: [PATCH 6/7] Fail gracefully on type mismatch

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.cpp  | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 3a2405a6195437..25fa15892203b3 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -62,9 +62,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
     }
 
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+    if (!arrayValue) {
+      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+    }
+
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
-        operands.getIndices());
+        op.getLoc(), arrayValue, operands.getIndices());
 
     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
     auto var =
@@ -82,10 +87,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+    if (!arrayValue) {
+      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+    }
 
     auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
-        operands.getIndices());
+        op.getLoc(), arrayValue, operands.getIndices());
     rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
                                                  operands.getValue());
     return success();

>From 93fff8a50d96f48f2062a96abf2ce062aab33ed9 Mon Sep 17 00:00:00 2001
From: Simon Camphausen <simon.camphausen at iml.fraunhofer.de>
Date: Tue, 2 Apr 2024 12:31:03 +0000
Subject: [PATCH 7/7] Reword comment

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 76421fd2a5d50b..4c895169ae89dc 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -839,8 +839,8 @@ LogicalResult emitc::SubscriptOp::verify() {
     return success();
   }
 
-  // The operand has opaque type, so we can't assume anything about arity or
-  // types of index operands.
+  // The operand has opaque type, so we can't assume anything about the number
+  // or types of index operands.
   return success();
 }
 



More information about the Mlir-commits mailing list