[Mlir-commits] [mlir] [mlir][EmitC] Support pointer-based memrefs in load/store lowering (PR #186828)

ioana ghiban llvmlistbot at llvm.org
Thu Mar 19 06:25:50 PDT 2026


https://github.com/ioghiban updated https://github.com/llvm/llvm-project/pull/186828

>From f7c3ea2db0591e32f209943061008197fb00c781 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Mon, 16 Mar 2026 14:43:12 +0100
Subject: [PATCH 1/2] [mlir][EmitC] Support pointer-based memrefs in load/store
 lowering

Assisted-by: ChatGPT (refine implementation + tests). I reviewed all code and tests before submission.
---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 97 ++++++++++++++++---
 .../memref-to-emitc-alloc-load-store.mlir     | 75 ++++++++++++++
 2 files changed, 159 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index c0c0d87289704..c15a63e3a09e9 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -140,6 +140,41 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
   return ptr;
 }
 
+static Value stripPointerUnrealizedCast(Value v) {
+  if (auto cast = v.getDefiningOp<UnrealizedConversionCastOp>())
+    if (cast.getNumOperands() == 1 &&
+        isa<emitc::PointerType>(cast.getOperand(0).getType()))
+      return cast.getOperand(0); // Pointer path
+  return Value();                // Array path
+}
+
+static Value computeRowMajorLinearIndex(Location loc, MemRefType memrefType,
+                                        ValueRange indices,
+                                        OpBuilder &builder) {
+  ArrayRef<int64_t> shape = memrefType.getShape();
+
+  Type idxType =
+      indices.empty() ? builder.getIndexType() : indices[0].getType();
+
+  Value linearIndex = indices.empty()
+                          ? emitc::ConstantOp::create(builder, loc, idxType,
+                                                      builder.getIndexAttr(0))
+                          : indices[0];
+
+  for (size_t i = 1; i < shape.size(); ++i) {
+    Value dimSize = emitc::ConstantOp::create(builder, loc, idxType,
+                                              builder.getIndexAttr(shape[i]));
+
+    linearIndex =
+        emitc::MulOp::create(builder, loc, idxType, linearIndex, dimSize);
+
+    linearIndex =
+        emitc::AddOp::create(builder, loc, idxType, linearIndex, indices[i]);
+  }
+
+  return linearIndex;
+}
+
 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -347,15 +382,32 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
 
     auto arrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
-    if (!arrayValue) {
-      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+    Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
+    if (!strippedPtr && arrayValue) {
+      auto subscript = emitc::SubscriptOp::create(
+          rewriter, op.getLoc(), arrayValue, operands.getIndices());
+
+      rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
+      return success();
     }
+    if (strippedPtr) {
+      Location loc = op.getLoc();
+      MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
+      ValueRange indices = operands.getIndices();
 
-    auto subscript = emitc::SubscriptOp::create(
-        rewriter, op.getLoc(), arrayValue, operands.getIndices());
+      // Compute row-major linear index
+      Value linearIndex =
+          computeRowMajorLinearIndex(loc, opMemrefType, indices, rewriter);
 
-    rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
-    return success();
+      auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
+      auto subscript = emitc::SubscriptOp::create(rewriter, op.getLoc(),
+                                                  typedPtr, linearIndex);
+
+      rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "expected array or pointer type");
   }
 };
 
@@ -367,17 +419,36 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto arrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
-    if (!arrayValue) {
-      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+    Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
+    if (!strippedPtr && arrayValue) {
+      auto subscript = emitc::SubscriptOp::create(
+          rewriter, op.getLoc(), arrayValue, operands.getIndices());
+      rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                   operands.getValue());
+      return success();
     }
 
-    auto subscript = emitc::SubscriptOp::create(
-        rewriter, op.getLoc(), arrayValue, operands.getIndices());
-    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
-                                                 operands.getValue());
-    return success();
+    if (strippedPtr) {
+      Location loc = op.getLoc();
+      MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
+      ValueRange indices = operands.getIndices();
+
+      // Compute row-major linear index
+      Value linearIndex =
+          computeRowMajorLinearIndex(loc, opMemrefType, indices, rewriter);
+      auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
+      auto subscript =
+          emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
+
+      rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                   operands.getValue());
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "expected array or pointer type");
   }
 };
+
 } // namespace
 
 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir
new file mode 100644
index 0000000000000..9606a3ca7ee78
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-to-emitc %s -split-input-file \
+// RUN: | FileCheck %s
+
+/// NOTE: This test intentionally uses `-convert-to-emitc` only.
+///
+/// The `-convert-memref-to-emitc` pass introduces
+/// `builtin.unrealized_conversion_cast` operations when lowering
+/// `memref.alloc` results (which are lowered to `emitc.ptr`) to the canonical
+/// memref representation used by the type converter (`emitc.array`).
+/// These casts are expected at that stage of the pipeline.
+///
+/// The purpose of this test is to verify the final lowering produced by
+/// `-convert-to-emitc`, where `memref.load` and `memref.store` conversions now
+/// handle pointer-backed buffers directly and eliminate the intermediate
+/// `unrealized_conversion_cast`.
+/// Therefore, the test must run the full EmitC conversion pipeline.
+
+/// The memref.alloc conversion always produces an `!emitc.ptr` (malloc + cast)
+// CHECK-LABEL: emitc.func private @memref_alloc_store(
+// CHECK-SAME:  %[[VAL:.*]]: f32,
+// CHECK-SAME:  %[[ARG_I:.*]]: !emitc.size_t,
+// CHECK-SAME:  %[[ARG_J:.*]]: !emitc.size_t)
+func.func private @memref_alloc_store(%v : f32, %i: index, %j: index) {
+  /// Size to alloc computation
+  // CHECK:     %[[SIZEOF_F32:.*]] = call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+  // CHECK:     %[[NUM_ELEMS:.*]] = "emitc.constant"() <{value = 32 : index}> : () -> index
+  // CHECK:     %[[TOTAL_BYTES:.*]] = mul %[[SIZEOF_F32]], %[[NUM_ELEMS]] : (!emitc.size_t, index) -> !emitc.size_t
+  // CHECK:     %[[MALLOC_PTR:.*]] = call_opaque "malloc"(%[[TOTAL_BYTES]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+  // CHECK:     %[[ELEM_PTR:.*]] = cast %[[MALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
+  %alloc = memref.alloc() : memref<4x8xf32>
+  /// Subscript computation
+  // CHECK:     %[[ROW_STRIDE:.*]] = "emitc.constant"() <{value = 8 : index}> : () -> !emitc.size_t
+  // CHECK:     %[[ROW_OFFSET:.*]] = mul %[[ARG_I]], %[[ROW_STRIDE]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK:     %[[LINEAR_INDEX:.*]] = add %[[ROW_OFFSET]], %[[ARG_J]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK:     %[[ELEM_LVALUE:.*]] = subscript %[[ELEM_PTR]]{{\[}}%[[LINEAR_INDEX]]] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     assign %[[VAL]] : f32 to %[[ELEM_LVALUE]] : <f32>
+  memref.store %v, %alloc[%i, %j] : memref<4x8xf32>
+  return
+}
+// CHECK-LABEL: emitc.func private @memref_alloc_load(
+// CHECK-SAME:  %[[ARG_I:.*]]: !emitc.size_t,
+// CHECK-SAME:  %[[ARG_J:.*]]: !emitc.size_t) -> f32
+func.func private @memref_alloc_load(%i: index, %j: index) -> f32 {
+  // CHECK:     %[[SIZEOF_F32:.*]] = call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+  // CHECK:     %[[NUM_ELEMS:.*]] = "emitc.constant"() <{value = 32 : index}> : () -> index
+  // CHECK:     %[[TOTAL_BYTES:.*]] = mul %[[SIZEOF_F32]], %[[NUM_ELEMS]] : (!emitc.size_t, index) -> !emitc.size_t
+  // CHECK:     %[[MALLOC_PTR:.*]] = call_opaque "malloc"(%[[TOTAL_BYTES]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
+  // CHECK:     %[[ELEM_PTR:.*]] = cast %[[MALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
+  %alloc = memref.alloc() : memref<4x8xf32>
+  // CHECK:     %[[ROW_STRIDE:.*]] = "emitc.constant"() <{value = 8 : index}> : () -> !emitc.size_t
+  // CHECK:     %[[ROW_OFFSET:.*]] = mul %[[ARG_I]], %[[ROW_STRIDE]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK:     %[[LINEAR_INDEX:.*]] = add %[[ROW_OFFSET]], %[[ARG_J]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
+  // CHECK:     %[[ELEM_LVALUE:.*]] = subscript %[[ELEM_PTR]]{{\[}}%[[LINEAR_INDEX]]] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     %[[LOADED_VAL:.*]] = load %[[ELEM_LVALUE]] : <f32>
+  %v = memref.load %alloc[%i, %j] : memref<4x8xf32>
+  return %v : f32
+}
+
+/// memref.load/memref.store on array-typed memrefs (e.g. function arguments)
+/// still lower through the original array subscript path.
+// CHECK-LABEL: emitc.func @memref_load_store(
+// CHECK-SAME:  %[[BUFF0:.*]]: !emitc.array<2xf32>,
+// CHECK-SAME:  %[[BUFF1:.*]]: !emitc.array<4x8xf32>,
+// CHECK-SAME:  %[[ARG_I:.*]]: !emitc.size_t,
+// CHECK-SAME:  %[[ARG_J:.*]]: !emitc.size_t) {
+func.func @memref_load_store(%buff0: memref<2xf32>,
+  %buff1: memref<4x8xf32>, %i : index, %j : index) {
+  // CHECK:     %[[ELEM_LVALUE0:.*]] = subscript %[[BUFF0]]{{\[}}%[[ARG_I]]] : (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     %[[VAL:.*]] = load %[[ELEM_LVALUE0]] : <f32>
+  %v = memref.load %buff0[%i] : memref<2xf32>
+  // CHECK:     %[[ELEM_LVALUE1:.*]] = subscript %[[BUFF1]]{{\[}}%[[ARG_I]], %[[ARG_J]]] : (!emitc.array<4x8xf32>, !emitc.size_t, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     assign %[[VAL]] : f32 to %[[ELEM_LVALUE1]] : <f32>
+  memref.store %v, %buff1[%i, %j] : memref<4x8xf32>
+  return
+}

>From 8cfc3b71abcc83d8f0157b40e6839bb168f1cf11 Mon Sep 17 00:00:00 2001
From: Ioana Ghiban <ioana.ghiban at arm.com>
Date: Thu, 19 Mar 2026 09:49:40 +0100
Subject: [PATCH 2/2] Address first round of comments

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 102 ++++++++----------
 .../memref-to-emitc-alloc-load-store.mlir     |  15 ++-
 2 files changed, 53 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index c15a63e3a09e9..87125b39e4666 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -140,38 +140,35 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
   return ptr;
 }
 
+// If `v` is defined through an unrealized cast and the source of that cast
+// is `emitc.ptr`, return the pointer; otherwise return a null Value.
 static Value stripPointerUnrealizedCast(Value v) {
   if (auto cast = v.getDefiningOp<UnrealizedConversionCastOp>())
     if (cast.getNumOperands() == 1 &&
         isa<emitc::PointerType>(cast.getOperand(0).getType()))
-      return cast.getOperand(0); // Pointer path
-  return Value();                // Array path
+      return cast.getOperand(0);
+  return Value();
 }
 
-static Value computeRowMajorLinearIndex(Location loc, MemRefType memrefType,
-                                        ValueRange indices,
-                                        OpBuilder &builder) {
+static Value computeRowMajorLinearIndex(ImplicitLocOpBuilder &builder,
+                                        MemRefType memrefType,
+                                        ValueRange indices) {
   ArrayRef<int64_t> shape = memrefType.getShape();
 
   Type idxType =
       indices.empty() ? builder.getIndexType() : indices[0].getType();
 
-  Value linearIndex = indices.empty()
-                          ? emitc::ConstantOp::create(builder, loc, idxType,
-                                                      builder.getIndexAttr(0))
-                          : indices[0];
+  Value linearIndex =
+      indices.empty()
+          ? emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(0))
+          : indices[0];
 
-  for (size_t i = 1; i < shape.size(); ++i) {
-    Value dimSize = emitc::ConstantOp::create(builder, loc, idxType,
-                                              builder.getIndexAttr(shape[i]));
-
-    linearIndex =
-        emitc::MulOp::create(builder, loc, idxType, linearIndex, dimSize);
-
-    linearIndex =
-        emitc::AddOp::create(builder, loc, idxType, linearIndex, indices[i]);
+  for (auto [dim, idx] : llvm::zip(shape.drop_front(), indices.drop_front())) {
+    Value dimSize =
+        emitc::ConstantOp::create(builder, idxType, builder.getIndexAttr(dim));
+    linearIndex = emitc::MulOp::create(builder, idxType, linearIndex, dimSize);
+    linearIndex = emitc::AddOp::create(builder, idxType, linearIndex, idx);
   }
-
   return linearIndex;
 }
 
@@ -374,40 +371,36 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-
+    Location loc = op.getLoc();
     auto resultTy = getTypeConverter()->convertType(op.getType());
     if (!resultTy) {
-      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+      return rewriter.notifyMatchFailure(loc, "cannot convert type");
     }
 
     auto arrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
     Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
     if (!strippedPtr && arrayValue) {
-      auto subscript = emitc::SubscriptOp::create(
-          rewriter, op.getLoc(), arrayValue, operands.getIndices());
+      auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
+                                                  operands.getIndices());
 
       rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
       return success();
     }
-    if (strippedPtr) {
-      Location loc = op.getLoc();
-      MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
-      ValueRange indices = operands.getIndices();
 
-      // Compute row-major linear index
-      Value linearIndex =
-          computeRowMajorLinearIndex(loc, opMemrefType, indices, rewriter);
+    if (!strippedPtr)
+      return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
+    MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
+    ValueRange indices = operands.getIndices();
 
-      auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
-      auto subscript = emitc::SubscriptOp::create(rewriter, op.getLoc(),
-                                                  typedPtr, linearIndex);
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
+    auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
+    auto subscript =
+        emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
 
-      rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
-      return success();
-    }
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "expected array or pointer type");
+    rewriter.replaceOpWithNewOp<emitc::LoadOp>(op, resultTy, subscript);
+    return success();
   }
 };
 
@@ -417,35 +410,32 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
     auto arrayValue =
         dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
     Value strippedPtr = stripPointerUnrealizedCast(operands.getMemref());
     if (!strippedPtr && arrayValue) {
-      auto subscript = emitc::SubscriptOp::create(
-          rewriter, op.getLoc(), arrayValue, operands.getIndices());
+      auto subscript = emitc::SubscriptOp::create(rewriter, loc, arrayValue,
+                                                  operands.getIndices());
       rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
                                                    operands.getValue());
       return success();
     }
 
-    if (strippedPtr) {
-      Location loc = op.getLoc();
-      MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
-      ValueRange indices = operands.getIndices();
+    if (!strippedPtr)
+      return rewriter.notifyMatchFailure(loc, "expected array or pointer type");
+    MemRefType opMemrefType = cast<MemRefType>(op.getMemref().getType());
+    ValueRange indices = operands.getIndices();
 
-      // Compute row-major linear index
-      Value linearIndex =
-          computeRowMajorLinearIndex(loc, opMemrefType, indices, rewriter);
-      auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
-      auto subscript =
-          emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value linearIndex = computeRowMajorLinearIndex(b, opMemrefType, indices);
+    auto typedPtr = cast<TypedValue<emitc::PointerType>>(strippedPtr);
+    auto subscript =
+        emitc::SubscriptOp::create(rewriter, loc, typedPtr, linearIndex);
 
-      rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
-                                                   operands.getValue());
-      return success();
-    }
-    return rewriter.notifyMatchFailure(op.getLoc(),
-                                       "expected array or pointer type");
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                 operands.getValue());
+    return success();
   }
 };
 
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir
index 9606a3ca7ee78..4b396005a7da3 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-alloc-load-store.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt -convert-to-emitc %s -split-input-file \
-// RUN: | FileCheck %s
+// RUN: mlir-opt -convert-to-emitc %s | FileCheck %s
 
 /// NOTE: This test intentionally uses `-convert-to-emitc` only.
 ///
@@ -21,18 +20,20 @@
 // CHECK-SAME:  %[[ARG_I:.*]]: !emitc.size_t,
 // CHECK-SAME:  %[[ARG_J:.*]]: !emitc.size_t)
 func.func private @memref_alloc_store(%v : f32, %i: index, %j: index) {
-  /// Size to alloc computation
+  /// Allocation size computation
   // CHECK:     %[[SIZEOF_F32:.*]] = call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
   // CHECK:     %[[NUM_ELEMS:.*]] = "emitc.constant"() <{value = 32 : index}> : () -> index
   // CHECK:     %[[TOTAL_BYTES:.*]] = mul %[[SIZEOF_F32]], %[[NUM_ELEMS]] : (!emitc.size_t, index) -> !emitc.size_t
+  /// Alloc
   // CHECK:     %[[MALLOC_PTR:.*]] = call_opaque "malloc"(%[[TOTAL_BYTES]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
   // CHECK:     %[[ELEM_PTR:.*]] = cast %[[MALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
   %alloc = memref.alloc() : memref<4x8xf32>
-  /// Subscript computation
+  /// Store subscript computation
   // CHECK:     %[[ROW_STRIDE:.*]] = "emitc.constant"() <{value = 8 : index}> : () -> !emitc.size_t
   // CHECK:     %[[ROW_OFFSET:.*]] = mul %[[ARG_I]], %[[ROW_STRIDE]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
   // CHECK:     %[[LINEAR_INDEX:.*]] = add %[[ROW_OFFSET]], %[[ARG_J]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
   // CHECK:     %[[ELEM_LVALUE:.*]] = subscript %[[ELEM_PTR]]{{\[}}%[[LINEAR_INDEX]]] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+  /// Store
   // CHECK:     assign %[[VAL]] : f32 to %[[ELEM_LVALUE]] : <f32>
   memref.store %v, %alloc[%i, %j] : memref<4x8xf32>
   return
@@ -61,14 +62,12 @@ func.func private @memref_alloc_load(%i: index, %j: index) -> f32 {
 // CHECK-LABEL: emitc.func @memref_load_store(
 // CHECK-SAME:  %[[BUFF0:.*]]: !emitc.array<2xf32>,
 // CHECK-SAME:  %[[BUFF1:.*]]: !emitc.array<4x8xf32>,
-// CHECK-SAME:  %[[ARG_I:.*]]: !emitc.size_t,
-// CHECK-SAME:  %[[ARG_J:.*]]: !emitc.size_t) {
 func.func @memref_load_store(%buff0: memref<2xf32>,
   %buff1: memref<4x8xf32>, %i : index, %j : index) {
-  // CHECK:     %[[ELEM_LVALUE0:.*]] = subscript %[[BUFF0]]{{\[}}%[[ARG_I]]] : (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     %[[ELEM_LVALUE0:.*]] = subscript %[[BUFF0]]{{.*}} (!emitc.array<2xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
   // CHECK:     %[[VAL:.*]] = load %[[ELEM_LVALUE0]] : <f32>
   %v = memref.load %buff0[%i] : memref<2xf32>
-  // CHECK:     %[[ELEM_LVALUE1:.*]] = subscript %[[BUFF1]]{{\[}}%[[ARG_I]], %[[ARG_J]]] : (!emitc.array<4x8xf32>, !emitc.size_t, !emitc.size_t) -> !emitc.lvalue<f32>
+  // CHECK:     %[[ELEM_LVALUE1:.*]] = subscript %[[BUFF1]]{{.*}} (!emitc.array<4x8xf32>, !emitc.size_t, !emitc.size_t) -> !emitc.lvalue<f32>
   // CHECK:     assign %[[VAL]] : f32 to %[[ELEM_LVALUE1]] : <f32>
   memref.store %v, %buff1[%i, %j] : memref<4x8xf32>
   return



More information about the Mlir-commits mailing list