[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `extract_strided_metadata` (PR #152208)
Jaden Angella
llvmlistbot at llvm.org
Tue Aug 5 14:34:53 PDT 2025
https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/152208
>From 6752ed3f1a051b172a5ffe9d0fae8276a741a14e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 31 Jul 2025 22:43:13 +0000
Subject: [PATCH 1/5] initial work on metadata ops
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 92 ++++++++++++++++++-
1 file changed, 90 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..1008fefc65cf0 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,10 +16,12 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
@@ -288,6 +290,90 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};
+
+struct ConvertExtractStridedMetadata final
+ : public OpConversionPattern<memref::ExtractStridedMetadataOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+ OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = extractStridedMetadataOp.getLoc();
+ Value source = extractStridedMetadataOp.getSource();
+
+ MemRefType memrefType = cast<MemRefType>(source.getType());
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type resultType = convertMemRefType(memrefType, getTypeConverter());
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(loc, "cannot convert result type");
+ }
+
+ auto baseptr =
+ cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
+ auto emitcType = convertMemRefType(baseptr, getTypeConverter());
+
+ auto [strides, offset] = memrefType.getStridesAndOffset();
+ Value offsetValue = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
+
+ SmallVector<Value> results;
+ results.push_back(extractStridedMetadataOp.getBaseBuffer());
+ results.push_back(offsetValue);
+
+ for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
+ // The op yields (base, offset, sizes..., strides...): emit every size
+ // here and append all strides afterwards, so ranks > 1 are not interleaved.
+ results.push_back(rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIndexAttr(memrefType.getDimSize(i))));
+ }
+ for (int64_t stride : strides)
+ results.push_back(rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(stride)));
+
+ rewriter.replaceOp(extractStridedMetadataOp, results);
+ return success();
+ }
+};
+
+struct ConvertReinterpretCastOp
+ : public OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
+
+ MemRefType targetMemRefType =
+ cast<MemRefType>(castOp.getResult().getType());
+
+ auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
+ auto targetInEmitC =
+ convertMemRefType(targetMemRefType, getTypeConverter());
+ if (!srcInEmitC || !targetInEmitC) {
+ return rewriter.notifyMatchFailure(castOp.getLoc(),
+ "cannot convert memref type");
+ }
+
+ // Create descriptor.
+ Location loc = castOp.getLoc();
+
+ auto vals = adaptor.getOperands();
+
+ auto res =
+ UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
+ .getResult(0);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -320,6 +406,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
+ ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+ ConvertReinterpretCastOp, ConvertStore>(converter,
+ patterns.getContext());
}
>From 57b34ae74ff63eebfb3eadb596f983cad0b0e683 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 1 Aug 2025 18:25:00 +0000
Subject: [PATCH 2/5] setting up variables
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 6 +++---
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 6 ++++--
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 6 +++---
3 files changed, 10 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7fe2da8f7e044..d19a32aa39734 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,10 +1191,10 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
- [MemAlloc<DefaultResource, 0, FullEffect>]>);
+ let results = (outs Res<AnyTypeOf<[EmitCType]>,
+ "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
- let hasVerifier = 1;
+ // let hasVerifier = 1;
}
def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 1008fefc65cf0..d836fa0066b7c 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -316,13 +316,15 @@ struct ConvertExtractStridedMetadata final
auto baseptr =
cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
auto emitcType = convertMemRefType(baseptr, getTypeConverter());
-
+ auto arrT = emitc::ArrayType::get(memrefType.getShape(), emitcType);
+ auto valVar = rewriter.create<emitc::VariableOp>(
+ loc, arrT, emitc::OpaqueAttr::get(rewriter.getContext(), ""));
auto [strides, offset] = memrefType.getStridesAndOffset();
Value offsetValue = rewriter.create<emitc::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
SmallVector<Value> results;
- results.push_back(extractStridedMetadataOp.getBaseBuffer());
+ results.push_back(valVar);
results.push_back(offsetValue);
for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c0902293cbf9..87d6f713ea35a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -889,9 +889,9 @@ LogicalResult SubOp::verify() {
// VariableOp
//===----------------------------------------------------------------------===//
-LogicalResult emitc::VariableOp::verify() {
- return verifyInitializationAttribute(getOperation(), getValueAttr());
-}
+// LogicalResult emitc::VariableOp::verify() {
+// return verifyInitializationAttribute(getOperation(), getValueAttr());
+// }
//===----------------------------------------------------------------------===//
// YieldOp
>From 2147fca8a77d0ede8f13e9027357992643512e3e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 21:00:06 +0000
Subject: [PATCH 3/5] separate the ops
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 77 +++++++------------
.../MemRefToEmitC/memref-to-emitc.mlir | 16 ++++
2 files changed, 43 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index d836fa0066b7c..428cdb0c1425a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -303,28 +303,39 @@ struct ConvertExtractStridedMetadata final
Value source = extractStridedMetadataOp.getSource();
MemRefType memrefType = cast<MemRefType>(source.getType());
- if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ if (!isMemRefTypeLegalForEmitC(memrefType))
return rewriter.notifyMatchFailure(
loc, "incompatible memref type for EmitC conversion");
- }
- Type resultType = convertMemRefType(memrefType, getTypeConverter());
- if (!resultType) {
- return rewriter.notifyMatchFailure(loc, "cannot convert result type");
- }
-
- auto baseptr =
- cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
- auto emitcType = convertMemRefType(baseptr, getTypeConverter());
- auto arrT = emitc::ArrayType::get(memrefType.getShape(), emitcType);
- auto valVar = rewriter.create<emitc::VariableOp>(
- loc, arrT, emitc::OpaqueAttr::get(rewriter.getContext(), ""));
+ // The base buffer is lowered to the address of the first element of the
+ // converted source array, i.e. `&src[0]...[0]`.
+ auto srcArrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ if (!srcArrayValue)
+ return rewriter.notifyMatchFailure(
+ loc, "expected the converted source to be an emitc.array");
+
+ emitc::ArrayType srcArrayType = srcArrayValue.getType();
+ Value zeroIndex = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+ // One zero index per dimension of the source array.
+ SmallVector<Value> indices(srcArrayType.getRank(), zeroIndex);
+
+ emitc::SubscriptOp firstElement = rewriter.create<emitc::SubscriptOp>(
+ loc, srcArrayValue, ValueRange(indices));
+
+ // Taking the address of the subscripted element yields the base pointer;
+ // its pointee type is the array's element type.
+ emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(srcArrayType.getElementType()),
+ rewriter.getStringAttr("&"), firstElement);
+
auto [strides, offset] = memrefType.getStridesAndOffset();
Value offsetValue = rewriter.create<emitc::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
SmallVector<Value> results;
- results.push_back(valVar);
+ results.push_back(srcPtr);
results.push_back(offsetValue);
for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
@@ -343,39 +354,6 @@ struct ConvertExtractStridedMetadata final
}
};
-struct ConvertReinterpretCastOp
- : public OpConversionPattern<memref::ReinterpretCastOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
-
- MemRefType targetMemRefType =
- cast<MemRefType>(castOp.getResult().getType());
-
- auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
- auto targetInEmitC =
- convertMemRefType(targetMemRefType, getTypeConverter());
- if (!srcInEmitC || !targetInEmitC) {
- return rewriter.notifyMatchFailure(castOp.getLoc(),
- "cannot convert memref type");
- }
-
- // Create descriptor.
- Location loc = castOp.getLoc();
-
- auto vals = adaptor.getOperands();
-
- auto res =
- UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
- .getResult(0);
-
- return success();
- }
-};
-
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -409,7 +387,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
- ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertReinterpretCastOp, ConvertStore>(converter,
- patterns.getContext());
+ ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..d36eaf3c2673a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,19 @@ module @globals {
return
}
}
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata
+func.func @extract_strided_metadata(%arg0: memref<1xi32>) {
+ // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xi32> to !emitc.array<1xi32>
+ // CHECK: %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+ // CHECK: %[[ELEM:.*]] = emitc.subscript %[[CAST]][%[[ZERO]]] : (!emitc.array<1xi32>, index) -> !emitc.lvalue<i32>
+ // CHECK: %{{.*}} = emitc.apply "&"(%[[ELEM]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+ // CHECK: %{{.*}} = "emitc.constant"() <{value = 0 : index}> : () -> index
+ // CHECK: %{{.*}} = "emitc.constant"() <{value = 1 : index}> : () -> index
+ // CHECK: %{{.*}} = "emitc.constant"() <{value = 1 : index}> : () -> index
+ %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<1xi32> -> memref<i32>, index, index, index
+ return
+}
+
>From a8edc52cd07e10668bf73539a957fe55369f9cc8 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 21:32:29 +0000
Subject: [PATCH 4/5] restore variableop functionality
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 6 +++---
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d19a32aa39734..04349d6bafb85 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,10 +1191,10 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs Res<AnyTypeOf<[EmitCType]>,
+ let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>,
"", [MemAlloc<DefaultResource, 0, FullEffect>]>);
- // let hasVerifier = 1;
+ let hasVerifier = 1;
}
def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 87d6f713ea35a..4c0902293cbf9 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -889,9 +889,9 @@ LogicalResult SubOp::verify() {
// VariableOp
//===----------------------------------------------------------------------===//
-// LogicalResult emitc::VariableOp::verify() {
-// return verifyInitializationAttribute(getOperation(), getValueAttr());
-// }
+LogicalResult emitc::VariableOp::verify() {
+ return verifyInitializationAttribute(getOperation(), getValueAttr());
+}
//===----------------------------------------------------------------------===//
// YieldOp
>From ac631e8629cecfb5084e0957aa770afe5267948e Mon Sep 17 00:00:00 2001
From: Jaden Angella <ajaden at google.com>
Date: Tue, 5 Aug 2025 14:33:52 -0700
Subject: [PATCH 5/5] Update EmitC.td
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 04349d6bafb85..7fe2da8f7e044 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,8 +1191,8 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>,
- "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
+ let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
+ [MemAlloc<DefaultResource, 0, FullEffect>]>);
let hasVerifier = 1;
}
More information about the Mlir-commits
mailing list