[Mlir-commits] [mlir] Expand the MemRef to EmitC pass (PR #148055)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 10 16:48:08 PDT 2025
https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/148055
>From 127a1ed01365f40c719a430d3d6ae56390728bb3 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 7 Jul 2025 18:49:20 +0000
Subject: [PATCH 1/3] Initial work on memref ops
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 37 ++++++++++++++++++-
1 file changed, 36 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..742d2bfff27de 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,6 +77,31 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ return failure();
+ }
+};
+
+static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
+ const TypeConverter &typeConverter) {
+ Type elementType = typeConverter.convertType(type.getElementType());
+ Type arrayTy = elementType;
+ // Shape has the outermost dim at index 0, so need to walk it backwards
+ auto shape = type.getShape();
+ if (shape.empty()) {
+ arrayTy = emitc::ArrayType::get({1}, arrayTy);
+ } else {
+ // For non-zero dimensions, use the original shape
+ arrayTy = emitc::ArrayType::get(shape, arrayTy);
+ }
+ return arrayTy;
+}
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -95,7 +120,8 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- auto resultTy = getTypeConverter()->convertType(op.getType());
+ auto resultTy =
+ convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
@@ -114,6 +140,15 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;
Attribute initialValue = operands.getInitialValueAttr();
+ if (op.getType().getRank() == 0) {
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+ auto scalarValue = elementsAttr.getSplatValue<Attribute>();
+
+ // Convert scalar value to single-element array
+ initialValue = DenseElementsAttr::get(
+ RankedTensorType::get({1}, elementsAttr.getElementType()),
+ {scalarValue});
+ }
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
>From c2372515aaeafba1d11293ecdc7cf013c130b9e5 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 10 Jul 2025 20:36:08 +0000
Subject: [PATCH 2/3] Lower rank-0 memref globals to scalar-typed EmitC globals
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 43 ++++++++-----------
1 file changed, 19 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 742d2bfff27de..f69a362395ef6 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
}
};
-static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
- const TypeConverter &typeConverter) {
- Type elementType = typeConverter.convertType(type.getElementType());
- Type arrayTy = elementType;
- // Shape has the outermost dim at index 0, so need to walk it backwards
- auto shape = type.getShape();
- if (shape.empty()) {
- arrayTy = emitc::ArrayType::get({1}, arrayTy);
- } else {
- // For non-zero dimensions, use the original shape
- arrayTy = emitc::ArrayType::get(shape, arrayTy);
- }
- return arrayTy;
-}
-
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
-
+ auto type = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
@@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- auto resultTy =
- convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+ // auto resultTy =
+ // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+ Type resultTy;
+ Type elementType = getTypeConverter()->convertType(type.getElementType());
+ auto shape = type.getShape();
+
+ if (shape.empty()) {
+ if (emitc::isSupportedFloatType(elementType)) {
+ resultTy = rewriter.getF32Type();
+ }
+ if (emitc::isSupportedIntegerType(elementType)) {
+ resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
+ }
+ } else {
+ resultTy = emitc::ArrayType::get(shape, elementType);
+ }
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
@@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
Attribute initialValue = operands.getInitialValueAttr();
if (op.getType().getRank() == 0) {
auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
- auto scalarValue = elementsAttr.getSplatValue<Attribute>();
-
- // Convert scalar value to single-element array
- initialValue = DenseElementsAttr::get(
- RankedTensorType::get({1}, elementsAttr.getElementType()),
- {scalarValue});
+ initialValue = elementsAttr.getSplatValue<Attribute>();
}
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
>From 36b61a6b5731a524b4e79e77d7505c7a5ef3d0f9 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 10 Jul 2025 23:47:46 +0000
Subject: [PATCH 3/3] Lower rank-0 memref.global and memref.get_global to scalar EmitC globals
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 48 ++++++++-----------
.../MemRefToEmitC/memref-to-emitc.mlir | 4 ++
2 files changed, 24 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index f69a362395ef6..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,8 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir;
@@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
-struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
- ConversionPatternRewriter &rewriter) const override {
- return failure();
- }
-};
-
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto type = op.getType();
+ MemRefType type = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
@@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- // auto resultTy =
- // convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+
Type resultTy;
- Type elementType = getTypeConverter()->convertType(type.getElementType());
- auto shape = type.getShape();
-
- if (shape.empty()) {
- if (emitc::isSupportedFloatType(elementType)) {
- resultTy = rewriter.getF32Type();
- }
- if (emitc::isSupportedIntegerType(elementType)) {
- resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
- }
- } else {
- resultTy = emitc::ArrayType::get(shape, elementType);
- }
+ if (type.getRank() == 0)
+ resultTy = getTypeConverter()->convertType(type.getElementType());
+ else
+ resultTy = getTypeConverter()->convertType(type);
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
@@ -140,7 +122,11 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;
Attribute initialValue = operands.getInitialValueAttr();
- if (op.getType().getRank() == 0) {
- auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
- initialValue = elementsAttr.getSplatValue<Attribute>();
+ // A rank-0 global is emitted with a scalar type, so unwrap the single element
+ // of its initializer. External globals (no initializer) and `uninitialized`
+ // globals (UnitAttr) carry no ElementsAttr and are passed through unchanged.
+ if (type.getRank() == 0) {
+ if (auto elementsAttr =
+ llvm::dyn_cast_if_present<ElementsAttr>(initialValue))
+ initialValue = elementsAttr.getSplatValue<Attribute>();
}
@@ -162,7 +144,19 @@ struct ConvertGetGlobal final
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto resultTy = getTypeConverter()->convertType(op.getType());
+ MemRefType type = op.getType();
+ Type resultTy;
+ if (type.getRank() == 0) {
+ // A rank-0 global is emitted as a scalar, which emitc.get_global yields
+ // as an lvalue. Convert the element type first so LValueType::get is
+ // never handed a null type when the conversion fails.
+ if (Type elementType =
+ getTypeConverter()->convertType(type.getElementType()))
+ resultTy = emitc::LValueType::get(elementType);
+ } else {
+ resultTy = getTypeConverter()->convertType(type);
+ }
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+ // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
memref.global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
+ // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+ %1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
}
More information about the Mlir-commits
mailing list