[Mlir-commits] [mlir] Expand the MemRef to EmitC pass (PR #148055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 10 16:48:08 PDT 2025


https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/148055

>From 127a1ed01365f40c719a430d3d6ae56390728bb3 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 7 Jul 2025 18:49:20 +0000
Subject: [PATCH 1/3] Initial work on memref ops

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 37 ++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..742d2bfff27de 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,6 +77,31 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return failure();
+  }
+};
+
+static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
+                                           const TypeConverter &typeConverter) {
+  Type elementType = typeConverter.convertType(type.getElementType());
+  Type arrayTy = elementType;
+  // Shape has the outermost dim at index 0, so need to walk it backwards
+  auto shape = type.getShape();
+  if (shape.empty()) {
+    arrayTy = emitc::ArrayType::get({1}, arrayTy);
+  } else {
+    // For non-zero dimensions, use the original shape
+    arrayTy = emitc::ArrayType::get(shape, arrayTy);
+  }
+  return arrayTy;
+}
+
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -95,7 +120,8 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
           op.getLoc(), "global variable with alignment requirement is "
                        "currently not supported");
     }
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+    auto resultTy =
+        convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
@@ -114,6 +140,15 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     bool externSpecifier = !staticSpecifier;
 
     Attribute initialValue = operands.getInitialValueAttr();
+    if (op.getType().getRank() == 0) {
+      auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+      auto scalarValue = elementsAttr.getSplatValue<Attribute>();
+
+      // Convert scalar value to single-element array
+      initialValue = DenseElementsAttr::get(
+          RankedTensorType::get({1}, elementsAttr.getElementType()),
+          {scalarValue});
+    }
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};
 

>From c2372515aaeafba1d11293ecdc7cf013c130b9e5 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 10 Jul 2025 20:36:08 +0000
Subject: [PATCH 2/3] Convert scalars to constants

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 43 ++++++++-----------
 1 file changed, 19 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 742d2bfff27de..f69a362395ef6 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -87,28 +87,13 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
   }
 };
 
-static Type convertGlobalMemrefTypeToEmitc(MemRefType type,
-                                           const TypeConverter &typeConverter) {
-  Type elementType = typeConverter.convertType(type.getElementType());
-  Type arrayTy = elementType;
-  // Shape has the outermost dim at index 0, so need to walk it backwards
-  auto shape = type.getShape();
-  if (shape.empty()) {
-    arrayTy = emitc::ArrayType::get({1}, arrayTy);
-  } else {
-    // For non-zero dimensions, use the original shape
-    arrayTy = emitc::ArrayType::get(shape, arrayTy);
-  }
-  return arrayTy;
-}
-
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-
+    auto type = op.getType();
     if (!op.getType().hasStaticShape()) {
       return rewriter.notifyMatchFailure(
           op.getLoc(), "cannot transform global with dynamic shape");
@@ -120,8 +105,23 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
           op.getLoc(), "global variable with alignment requirement is "
                        "currently not supported");
     }
-    auto resultTy =
-        convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+    // auto resultTy =
+    //     convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+    Type resultTy;
+    Type elementType = getTypeConverter()->convertType(type.getElementType());
+    auto shape = type.getShape();
+
+    if (shape.empty()) {
+      if (emitc::isSupportedFloatType(elementType)) {
+        resultTy = rewriter.getF32Type();
+      }
+      if (emitc::isSupportedIntegerType(elementType)) {
+        resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
+      }
+    } else {
+      resultTy = emitc::ArrayType::get(shape, elementType);
+    }
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
@@ -142,12 +142,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     Attribute initialValue = operands.getInitialValueAttr();
     if (op.getType().getRank() == 0) {
       auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
-      auto scalarValue = elementsAttr.getSplatValue<Attribute>();
-
-      // Convert scalar value to single-element array
-      initialValue = DenseElementsAttr::get(
-          RankedTensorType::get({1}, elementsAttr.getElementType()),
-          {scalarValue});
+      initialValue = elementsAttr.getSplatValue<Attribute>();
     }
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};

>From 36b61a6b5731a524b4e79e77d7505c7a5ef3d0f9 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 10 Jul 2025 23:47:46 +0000
Subject: [PATCH 3/3] global and getGlobal

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 48 ++++++++-----------
 .../MemRefToEmitC/memref-to-emitc.mlir        |  4 ++
 2 files changed, 24 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index f69a362395ef6..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -77,23 +79,13 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
-struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(memref::CopyOp op, OpAdaptor operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    return failure();
-  }
-};
-
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto type = op.getType();
-    if (!op.getType().hasStaticShape()) {
+    MemRefType type = op.getType();
+    if (!type.hasStaticShape()) {
       return rewriter.notifyMatchFailure(
           op.getLoc(), "cannot transform global with dynamic shape");
@@ -105,22 +97,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
           op.getLoc(), "global variable with alignment requirement is "
                        "currently not supported");
     }
-    // auto resultTy =
-    //     convertGlobalMemrefTypeToEmitc(op.getType(), *getTypeConverter());
+
     Type resultTy;
-    Type elementType = getTypeConverter()->convertType(type.getElementType());
-    auto shape = type.getShape();
-
-    if (shape.empty()) {
-      if (emitc::isSupportedFloatType(elementType)) {
-        resultTy = rewriter.getF32Type();
-      }
-      if (emitc::isSupportedIntegerType(elementType)) {
-        resultTy = rewriter.getIntegerType(elementType.getIntOrFloatBitWidth());
-      }
-    } else {
-      resultTy = emitc::ArrayType::get(shape, elementType);
-    }
+    if (type.getRank() == 0)
+      resultTy = getTypeConverter()->convertType(type.getElementType());
+    else
+      resultTy = getTypeConverter()->convertType(type);
 
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
@@ -140,9 +122,11 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     bool externSpecifier = !staticSpecifier;
 
     Attribute initialValue = operands.getInitialValueAttr();
-    if (op.getType().getRank() == 0) {
-      auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
-      initialValue = elementsAttr.getSplatValue<Attribute>();
+    if (type.getRank() == 0) {
+      // Rank-0 globals are emitted as scalars: unwrap the single element.
+      // Declarations and `uninitialized` globals have no ElementsAttr here.
+      if (auto elementsAttr = dyn_cast_if_present<ElementsAttr>(initialValue))
+        initialValue = elementsAttr.getSplatValue<Attribute>();
     }
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};
@@ -162,7 +146,19 @@ struct ConvertGetGlobal final
   matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resultTy = getTypeConverter()->convertType(op.getType());
+    MemRefType type = op.getType();
+    Type resultTy;
+    if (type.getRank() == 0) {
+      // Rank-0 globals are converted to scalar emitc.globals, so they are
+      // accessed as an lvalue of the element type. Convert the element type
+      // first so that no LValueType is built around a null type.
+      Type elementTy = getTypeConverter()->convertType(type.getElementType());
+      if (elementTy)
+        resultTy = emitc::LValueType::get(elementTy);
+    } else {
+      resultTy = getTypeConverter()->convertType(type);
+    }
+
     if (!resultTy) {
       return rewriter.notifyMatchFailure(op.getLoc(),
                                          "cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
 module @globals {
   memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
   // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+  memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+  // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
   memref.global @public_global : memref<3x7xf32>
   // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
   memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
   func.func @use_global() {
     // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
     %0 = memref.get_global @public_global : memref<3x7xf32>
+    // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+    %1 = memref.get_global @__constant_xi32 : memref<i32>
     return
   }
 }



More information about the Mlir-commits mailing list