[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `reinterpret_cast` (PR #152610)
Jaden Angella
llvmlistbot at llvm.org
Fri Aug 8 14:08:33 PDT 2025
https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/152610
>From c6ea99574fde268a83f20642f507aafe1ac6edf6 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 6 Aug 2025 20:18:38 +0000
Subject: [PATCH 1/3] needs improvement
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 84 ++++++++++++++++++-
1 file changed, 83 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..60cccb0ece8f2 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -21,7 +21,9 @@
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
+#include <string>
using namespace mlir;
@@ -269,6 +271,85 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}
};
+// Lowers memref.reinterpret_cast: takes the address of element [0, ..., 0]
+// of the converted source array and replaces the cast with a value of
+// pointer type.
+//
+// NOTE(review): the cast's offset/sizes/strides operands are never read; only
+// the static shape of the result memref type is consulted. A non-zero offset
+// or non-identity strides would therefore seem to be lowered as if they were
+// the identity layout -- confirm, or reject such casts with
+// notifyMatchFailure.
+// NOTE(review): the replacement value has a pointer type, while the result
+// memref converts (via convertMemRefType) to targetInEmitC; if those differ,
+// users of the cast result will need a materialization -- verify.
+struct ConvertReinterpretCastOp final
+ : public OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // NOTE(review): if memref.reinterpret_cast accepts an unranked source,
+ // cast<MemRefType> below asserts on it; dyn_cast + notifyMatchFailure would
+ // be safer -- confirm against the op definition.
+ MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
+
+ MemRefType targetMemRefType =
+ cast<MemRefType>(castOp.getResult().getType());
+
+ // Both types must be convertible; srcInEmitC is only used for this check.
+ auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
+ auto targetInEmitC =
+ convertMemRefType(targetMemRefType, getTypeConverter());
+ if (!srcInEmitC || !targetInEmitC) {
+ return rewriter.notifyMatchFailure(castOp.getLoc(),
+ "cannot convert memref type");
+ }
+ Location loc = castOp.getLoc();
+
+ // Assumes the converted source is an emitc.array value; cast<> asserts
+ // otherwise (e.g. if the type converter maps some memrefs to non-array
+ // types -- verify).
+ auto srcArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(adaptor.getSource());
+
+ // A single index-typed zero constant, reused for every subscript dimension.
+ emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+
+ // Builds `&arrayValue[0]...[0]` (one zero index per array dimension) and
+ // returns the emitc.apply op whose result points to the element type.
+ auto createPointerFromEmitcArray =
+ [loc, &rewriter, &zeroIndex](
+ mlir::TypedValue<emitc::ArrayType> arrayValue) -> emitc::ApplyOp {
+ int64_t rank = arrayValue.getType().getRank();
+ llvm::SmallVector<mlir::Value> indices;
+ for (int i = 0; i < rank; ++i) {
+ indices.push_back(zeroIndex);
+ }
+
+ emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
+ loc, arrayValue, mlir::ValueRange(indices));
+ emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
+ loc, emitc::PointerType::get(arrayValue.getType().getElementType()),
+ rewriter.getStringAttr("&"), subPtr);
+
+ return ptr;
+ };
+ auto [strides, offset] = targetMemRefType.getStridesAndOffset();
+ // Value offsetValue = rewriter.create<emitc::ConstantOp>(
+ // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
+
+ auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
+ // emitc::PointerType targetPointerType =
+ // emitc::PointerType::get(srcArrayValue.getType().getElementType());
+
+ auto dimensions = targetMemRefType.getShape();
+ std::string reinterpretCastName = llvm::formatv(
+ "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
+ std::string dimensionsStr;
+ for (auto dim : dimensions) {
+ dimensionsStr += llvm::formatv("[{0}]", dim);
+ }
+ reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
+ reinterpretCastName += ">";
+
+ reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
+
+ std::string outputStr = llvm::formatv(
+ "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
+ auto outputType = emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), outputStr));
+
+ emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
+ loc, outputType,
+ emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
+
+ rewriter.replaceOp(castOp, reinterpretOp.getResult());
+ return success();
+ }
+};
+
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -321,5 +402,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
+ // ConvertReinterpretCastOp (defined above) lowers memref.reinterpret_cast.
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ ConvertLoad, ConvertReinterpretCastOp, ConvertStore>(
+ converter, patterns.getContext());
}
>From d09ccc0fbcf77ac09bafbc9b96b4bc6332c7f02a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 7 Aug 2025 23:06:46 +0000
Subject: [PATCH 2/3] almost functional option
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 41 ++++++++-----------
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 13 +++++-
.../memref-to-emitc-reinterpret-cast.mlir | 16 ++++++++
3 files changed, 43 insertions(+), 27 deletions(-)
create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 60cccb0ece8f2..97b926861482f 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
@@ -316,36 +317,26 @@ struct ConvertReinterpretCastOp final
return ptr;
};
- auto [strides, offset] = targetMemRefType.getStridesAndOffset();
- // Value offsetValue = rewriter.create<emitc::ConstantOp>(
- // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
- // emitc::PointerType targetPointerType =
- // emitc::PointerType::get(srcArrayValue.getType().getElementType());
-
- auto dimensions = targetMemRefType.getShape();
- std::string reinterpretCastName = llvm::formatv(
- "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType());
- std::string dimensionsStr;
- for (auto dim : dimensions) {
- dimensionsStr += llvm::formatv("[{0}]", dim);
- }
- reinterpretCastName += llvm::formatv("{0}>", dimensionsStr);
- reinterpretCastName += ">";
-
- reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0));
+ // 1. Create a TypeAttr for the target type.
+ TypeAttr targetTypeAttr =
+ TypeAttr::get(emitc::PointerType::get(targetInEmitC));
+ IntegerAttr resty = rewriter.getIndexAttr(0);
- std::string outputStr = llvm::formatv(
- "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr);
- auto outputType = emitc::PointerType::get(
- emitc::OpaqueType::get(rewriter.getContext(), outputStr));
+ // 2. Create an ArrayAttr with the TypeAttr. This will be the
+ // templateArgsAttr.
+ ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr});
- emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>(
- loc, outputType,
- emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName));
+ auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)},
+ /*callee=*/"reinterpret_cast",
+ /*args*/ rewriter.getArrayAttr({resty}),
+ /*template_args=*/templateArgsAttr,
+ /*operands=*/ValueRange{srcPtr.getResult()});
- rewriter.replaceOp(castOp, reinterpretOp.getResult());
+ rewriter.replaceOp(castOp, reinterpretCastCall.getResults());
return success();
}
};
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8e83e455d1a7f..3d47da8c439de 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1848,8 +1848,17 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto lType = dyn_cast<emitc::LValueType>(type))
return emitType(loc, lType.getValueType());
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
- if (isa<ArrayType>(pType.getPointee()))
- return emitError(loc, "cannot emit pointer to array type ") << type;
+ // Check if the pointee is an array type.
+ if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) {
+ // Handle pointer to array: `element_type (*)[dim]`.
+ if (failed(emitType(loc, aType.getElementType())))
+ return failure();
+ os << "(*)";
+ for (auto dim : aType.getShape())
+ os << "[" << dim << "]";
+ return success();
+ }
+ // Handle standard pointer: `element_type*`.
if (failed(emitType(loc, pType.getPointee())))
return failure();
os << "*";
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir
new file mode 100644
index 0000000000000..5d610fe1ae642
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-reinterpret-cast.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// memref.reinterpret_cast is lowered to the address of the source array's
+// first element, cast (emitc.cast) to a pointer to the result array type.
+
+// CHECK-LABEL: func.func @casting(
+// CHECK-SAME:    %[[ARG0:.*]]: memref<999xi32>)
+// CHECK:         %[[ARR:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<999xi32> to !emitc.array<999xi32>
+// CHECK:         %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK:         %[[ELEM:.*]] = emitc.subscript %[[ARR]][%[[ZERO]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
+// CHECK:         %[[PTR:.*]] = emitc.apply "&"(%[[ELEM]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+// CHECK:         %{{.*}} = emitc.cast %[[PTR]] : !emitc.ptr<i32> to !emitc.ptr<!emitc.array<1x1x999xi32>>
+// CHECK-NOT:     memref.reinterpret_cast
+// CHECK:         return
+func.func @casting(%arg0: memref<999xi32>) {
+ %reinterpret_cast_5 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1, 1, 999], strides: [999, 999, 1] : memref<999xi32> to memref<1x1x999xi32>
+ return
+}
>From 19e23d560b00008e664c2e85df9840ae71f51ee1 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 8 Aug 2025 21:08:08 +0000
Subject: [PATCH 3/3] cpp output that compiles
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 19 +++----------------
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 14 ++++++++++++++
2 files changed, 17 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 97b926861482f..18bd79de2ec2a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -319,24 +319,11 @@ struct ConvertReinterpretCastOp final
};
auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
- // 1. Create a TypeAttr for the target type.
- TypeAttr targetTypeAttr =
- TypeAttr::get(emitc::PointerType::get(targetInEmitC));
- IntegerAttr resty = rewriter.getIndexAttr(0);
- // 2. Create an ArrayAttr with the TypeAttr. This will be the
- // templateArgsAttr.
- ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr});
+ auto castCall = rewriter.create<emitc::CastOp>(
+ loc, emitc::PointerType::get(targetInEmitC), srcPtr.getResult());
- auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>(
- loc,
- /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)},
- /*callee=*/"reinterpret_cast",
- /*args*/ rewriter.getArrayAttr({resty}),
- /*template_args=*/templateArgsAttr,
- /*operands=*/ValueRange{srcPtr.getResult()});
-
- rewriter.replaceOp(castOp, reinterpretCastCall.getResults());
+ rewriter.replaceOp(castOp, castCall);
return success();
}
};
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3d47da8c439de..6ec1c070fde83 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1756,6 +1756,20 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
StringRef name) {
+ if (auto pType = dyn_cast<emitc::PointerType>(type)) {
+ if (auto aType = dyn_cast<emitc::ArrayType>(pType.getPointee())) {
+ if (failed(emitType(loc, aType.getElementType())))
+ return failure();
+ os << " (*" << name << ")";
+ for (auto dim : aType.getShape())
+ os << "[" << dim << "]";
+ return success();
+ }
+ if (failed(emitType(loc, pType.getPointee())))
+ return failure();
+ os << " *" << name;
+ return success();
+ }
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, arrType.getElementType())))
return failure();
More information about the Mlir-commits
mailing list