[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `extract_strided_metadata` (PR #152208)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 5 14:35:17 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
Author: Jaden Angella (Jaddyen)
<details>
<summary>Changes</summary>
This patch lowers `memref.extract_strided_metadata` to a pointer to the first element of the array, followed by the offset, the sizes, and the strides.
From:
```
func.func @copying(%arg18: memref<1xi32>) {
%base_buffer_485, %offset_486, %sizes_487, %strides_488 = memref.extract_strided_metadata %arg18 : memref<1xi32> -> memref<i32>, index, index, index
return
}
```
To:
```cpp
void extract_strided_metadata(int32_t v1[1]) {
size_t v2 = 0;
int32_t* v3 = &v1[v2];
size_t v4 = 0;
size_t v5 = 1;
size_t v6 = 1;
return;
}
```
---
Full diff: https://github.com/llvm/llvm-project/pull/152208.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+69-2)
- (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+16)
``````````diff
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..428cdb0c1425a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,10 +16,12 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
@@ -288,6 +290,70 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};
+
+/// Lowers `memref.extract_strided_metadata` to EmitC values: a pointer to the
+/// first element of the converted source array, followed by `index` constants
+/// for the offset, the sizes and the strides of the source memref type.
+///
+/// The pattern only applies when the source memref type is accepted by
+/// `isMemRefTypeLegalForEmitC` and the adapted operand is an `emitc.array`.
+struct ConvertExtractStridedMetadata final
+    : public OpConversionPattern<memref::ExtractStridedMetadataOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value source = extractStridedMetadataOp.getSource();
+
+    MemRefType memrefType = cast<MemRefType>(source.getType());
+    if (!isMemRefTypeLegalForEmitC(memrefType))
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+
+    // Report a match failure instead of asserting when the adapted operand has
+    // not been converted to an EmitC array. This is checked before any new op
+    // is created.
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    if (!srcArrayValue)
+      return rewriter.notifyMatchFailure(loc, "expected array type");
+
+    // Base pointer: take the address of element [0, ..., 0] of the array.
+    Value zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    emitc::ArrayType arrayType = srcArrayValue.getType();
+    SmallVector<Value> indices(static_cast<size_t>(arrayType.getRank()),
+                               zeroIndex);
+    emitc::SubscriptOp firstElement = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArrayValue, ValueRange(indices));
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(arrayType.getElementType()),
+        rewriter.getStringAttr("&"), firstElement);
+
+    // NOTE(review): offset and strides are materialized as constants, which
+    // presumes `isMemRefTypeLegalForEmitC` rejects dynamic layouts - confirm.
+    auto [strides, offset] = memrefType.getStridesAndOffset();
+
+    SmallVector<Value> results;
+    results.reserve(2 + 2 * memrefType.getRank());
+    results.push_back(srcPtr);
+    results.push_back(rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)));
+
+    // Appends one `index` constant per entry of `values` to `results`.
+    auto pushIndexConstants = [&](ArrayRef<int64_t> values) {
+      for (int64_t value : values)
+        results.push_back(rewriter.create<emitc::ConstantOp>(
+            loc, rewriter.getIndexType(), rewriter.getIndexAttr(value)));
+    };
+
+    // The op's results are ordered (base, offset, sizes..., strides...), so
+    // every size must precede every stride. Interleaving them per dimension
+    // is only correct for rank <= 1.
+    pushIndexConstants(memrefType.getShape());
+    pushIndexConstants(strides);
+
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -320,6 +386,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
+ ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..d36eaf3c2673a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,19 @@ module @globals {
return
}
}
+
+// -----
+
+// Checks the lowering of memref.extract_strided_metadata on a rank-1 memref:
+// a pointer to element 0 followed by offset, size and stride constants.
+// SSA names are captured with FileCheck variables rather than hard-coded.
+// CHECK-LABEL: reinterpret_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<1xi32>
+func.func @reinterpret_cast(%arg18: memref<1xi32>) {
+  // CHECK: %[[ARRAY:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1xi32> to !emitc.array<1xi32>
+  // CHECK: %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: %[[ELEM:.*]] = emitc.subscript %[[ARRAY]][%[[ZERO]]] : (!emitc.array<1xi32>, index) -> !emitc.lvalue<i32>
+  // CHECK: %[[PTR:.*]] = emitc.apply "&"(%[[ELEM]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+  // CHECK: %[[OFFSET:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: %[[SIZE:.*]] = "emitc.constant"() <{value = 1 : index}> : () -> index
+  // CHECK: %[[STRIDE:.*]] = "emitc.constant"() <{value = 1 : index}> : () -> index
+  %base_buffer_485, %offset_486, %sizes_487, %strides_488 = memref.extract_strided_metadata %arg18 : memref<1xi32> -> memref<i32>, index, index, index
+  return
+}
+
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/152208
More information about the Mlir-commits
mailing list