[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `extract_strided_metadata` (PR #152208)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 5 14:35:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-emitc

Author: Jaden Angella (Jaddyen)

<details>
<summary>Changes</summary>

This patch lowers `memref.extract_strided_metadata` to a pointer to the first element of the converted array, plus `index` constants for the offset, the sizes and the strides (in the op's result order).

From:
```
func.func @extract_strided_metadata(%arg0: memref<1xi32>) {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<1xi32> -> memref<i32>, index, index, index
  return
}
```
To:
```cpp
void extract_strided_metadata(int32_t v1[1]) {
  size_t v2 = 0;
  int32_t* v3 = &v1[v2];
  size_t v4 = 0;
  size_t v5 = 1;
  size_t v6 = 1;
  return;
}
```

---
Full diff: https://github.com/llvm/llvm-project/pull/152208.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+69-2) 
- (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..428cdb0c1425a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <cstdint>
 
@@ -288,6 +290,70 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
     return success();
   }
 };
+
+/// Lowers `memref.extract_strided_metadata` on an EmitC-legal memref.
+///
+/// The base buffer becomes a pointer to the first element of the converted
+/// `!emitc.array` (i.e. `&src[0]...[0]`). The offset, sizes and strides are
+/// materialized as `index` constants, so all of them must be static;
+/// dynamic values are rejected before any IR is created.
+struct ConvertExtractStridedMetadata final
+    : public OpConversionPattern<memref::ExtractStridedMetadataOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = extractStridedMetadataOp.getLoc();
+    MemRefType memrefType =
+        cast<MemRefType>(extractStridedMetadataOp.getSource().getType());
+    if (!isMemRefTypeLegalForEmitC(memrefType))
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+
+    // Validate everything up front so a bail-out leaves the IR untouched.
+    auto [strides, offset] = memrefType.getStridesAndOffset();
+    if (!memrefType.hasStaticShape() || ShapedType::isDynamic(offset) ||
+        llvm::any_of(strides, ShapedType::isDynamic))
+      return rewriter.notifyMatchFailure(
+          loc, "dynamic sizes, strides or offset are not supported");
+
+    auto createIndexConstant = [&](int64_t value) -> Value {
+      return rewriter.create<emitc::ConstantOp>(
+          loc, rewriter.getIndexType(), rewriter.getIndexAttr(value));
+    };
+
+    // Base buffer: `&src[0]...[0]`.
+    // NOTE(review): this replaces a 0-d memref result with an !emitc.ptr;
+    // confirm that remaining users of the base buffer are converted too.
+    auto srcArray = cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    emitc::ArrayType arrayType = srcArray.getType();
+    Value zeroIndex = createIndexConstant(0);
+    SmallVector<Value> zeroIndices(arrayType.getRank(), zeroIndex);
+    emitc::SubscriptOp firstElement = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArray, ValueRange(zeroIndices));
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(arrayType.getElementType()),
+        rewriter.getStringAttr("&"), firstElement);
+
+    // Results are ordered (base, offset, sizes..., strides...); interleaving
+    // size/stride per dimension would bind values to the wrong results as
+    // soon as the rank exceeds 1.
+    SmallVector<Value> results;
+    results.reserve(2 + 2 * memrefType.getRank());
+    results.push_back(srcPtr);
+    results.push_back(createIndexConstant(offset));
+    for (int64_t size : memrefType.getShape())
+      results.push_back(createIndexConstant(size));
+    for (int64_t stride : strides)
+      results.push_back(createIndexConstant(stride));
+
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -320,6 +386,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
+               ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+               ConvertStore>(converter, patterns.getContext());
 }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..d36eaf3c2673a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,19 @@ module @globals {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata
+func.func @extract_strided_metadata(%arg0: memref<1xi32>) {
+  // CHECK: %[[ARR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xi32> to !emitc.array<1xi32>
+  // CHECK: %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: %[[ELEM:.*]] = emitc.subscript %[[ARR]][%[[ZERO]]] : (!emitc.array<1xi32>, index) -> !emitc.lvalue<i32>
+  // CHECK: emitc.apply "&"(%[[ELEM]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+  // CHECK: "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: "emitc.constant"() <{value = 1 : index}> : () -> index
+  // CHECK: "emitc.constant"() <{value = 1 : index}> : () -> index
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<1xi32> -> memref<i32>, index, index, index
+  return
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/152208


More information about the Mlir-commits mailing list