[Mlir-commits] [mlir] [mlir][EmitC] Expand the MemRefToEmitC pass - Lowering `extract_strided_metadata` (PR #152208)

Jaden Angella llvmlistbot at llvm.org
Tue Aug 5 14:34:53 PDT 2025


https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/152208

>From 6752ed3f1a051b172a5ffe9d0fae8276a741a14e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 31 Jul 2025 22:43:13 +0000
Subject: [PATCH 1/5] Add initial extract_strided_metadata and reinterpret_cast lowerings

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 92 ++++++++++++++++++-
 1 file changed, 90 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..1008fefc65cf0 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
 #include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <cstdint>
 
@@ -288,6 +290,90 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
     return success();
   }
 };
+
+struct ConvertExtractStridedMetadata final
+    : public OpConversionPattern<memref::ExtractStridedMetadataOp> {
+  using OpConversionPattern::OpConversionPattern;
+  // Emits the static offset, sizes and strides as emitc.constant ops.
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value source = extractStridedMetadataOp.getSource();
+
+    MemRefType memrefType = cast<MemRefType>(source.getType());
+    if (!isMemRefTypeLegalForEmitC(memrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+    }
+
+    Type resultType = convertMemRefType(memrefType, getTypeConverter());
+    if (!resultType) {
+      return rewriter.notifyMatchFailure(loc, "cannot convert result type");
+    }
+
+    auto baseptr =
+        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
+    auto emitcType = convertMemRefType(baseptr, getTypeConverter());
+
+    auto [strides, offset] = memrefType.getStridesAndOffset();
+    Value offsetValue = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
+
+    SmallVector<Value> results;
+    results.push_back(extractStridedMetadataOp.getBaseBuffer());
+    results.push_back(offsetValue);
+
+    for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
+      results.push_back(rewriter.create<emitc::ConstantOp>(
+          loc, rewriter.getIndexType(),
+          rewriter.getIndexAttr(memrefType.getDimSize(i))));
+    }
+    // extract_strided_metadata yields base, offset, sizes..., strides...;
+    // all sizes precede all strides, so the strides get their own loop.
+    for (int64_t stride : strides)
+      results.push_back(rewriter.create<emitc::ConstantOp>(
+          loc, rewriter.getIndexType(), rewriter.getIndexAttr(stride)));
+
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
+struct ConvertReinterpretCastOp
+    : public OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
+
+    MemRefType targetMemRefType =
+        cast<MemRefType>(castOp.getResult().getType());
+
+    auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
+    auto targetInEmitC =
+        convertMemRefType(targetMemRefType, getTypeConverter());
+    if (!srcInEmitC || !targetInEmitC) {
+      return rewriter.notifyMatchFailure(castOp.getLoc(),
+                                         "cannot convert memref type");
+    }
+
+    // Create descriptor.
+    Location loc = castOp.getLoc();
+
+    auto vals = adaptor.getOperands();
+
+    auto res =
+        UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
+            .getResult(0);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -320,6 +406,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
+               ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+               ConvertReinterpretCastOp, ConvertStore>(converter,
+                                                       patterns.getContext());
 }

>From 57b34ae74ff63eebfb3eadb596f983cad0b0e683 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 1 Aug 2025 18:25:00 +0000
Subject: [PATCH 2/5] Back the base buffer with an emitc.variable (temporarily relax VariableOp)

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td         | 6 +++---
 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 6 ++++--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp                 | 6 +++---
 3 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7fe2da8f7e044..d19a32aa39734 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,10 +1191,10 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
   }];
 
   let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
-  let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
-                          [MemAlloc<DefaultResource, 0, FullEffect>]>);
+  let results = (outs Res<AnyTypeOf<[EmitCType]>,
+                          "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
 
-  let hasVerifier = 1;
+  // let hasVerifier = 1;
 }
 
 def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 1008fefc65cf0..d836fa0066b7c 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -316,13 +316,15 @@ struct ConvertExtractStridedMetadata final
     auto baseptr =
         cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
     auto emitcType = convertMemRefType(baseptr, getTypeConverter());
-
+    auto arrT = emitc::ArrayType::get(memrefType.getShape(), emitcType);
+    auto valVar = rewriter.create<emitc::VariableOp>(
+        loc, arrT, emitc::OpaqueAttr::get(rewriter.getContext(), ""));
     auto [strides, offset] = memrefType.getStridesAndOffset();
     Value offsetValue = rewriter.create<emitc::ConstantOp>(
         loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));
 
     SmallVector<Value> results;
-    results.push_back(extractStridedMetadataOp.getBaseBuffer());
+    results.push_back(valVar);
     results.push_back(offsetValue);
 
     for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c0902293cbf9..87d6f713ea35a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -889,9 +889,9 @@ LogicalResult SubOp::verify() {
 // VariableOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult emitc::VariableOp::verify() {
-  return verifyInitializationAttribute(getOperation(), getValueAttr());
-}
+// LogicalResult emitc::VariableOp::verify() {
+//   return verifyInitializationAttribute(getOperation(), getValueAttr());
+// }
 
 //===----------------------------------------------------------------------===//
 // YieldOp

>From 2147fca8a77d0ede8f13e9027357992643512e3e Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 21:00:06 +0000
Subject: [PATCH 3/5] Lower the base buffer to &src[0]; drop the reinterpret_cast lowering

---
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 77 +++++++------------
 .../MemRefToEmitC/memref-to-emitc.mlir        | 16 ++++
 2 files changed, 43 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index d836fa0066b7c..428cdb0c1425a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -303,28 +303,39 @@ struct ConvertExtractStridedMetadata final
     Value source = extractStridedMetadataOp.getSource();

     MemRefType memrefType = cast<MemRefType>(source.getType());
-    if (!isMemRefTypeLegalForEmitC(memrefType)) {
+    if (!isMemRefTypeLegalForEmitC(memrefType))
       return rewriter.notifyMatchFailure(
           loc, "incompatible memref type for EmitC conversion");
-    }

-    Type resultType = convertMemRefType(memrefType, getTypeConverter());
-    if (!resultType) {
-      return rewriter.notifyMatchFailure(loc, "cannot convert result type");
-    }
-
-    auto baseptr =
-        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType());
-    auto emitcType = convertMemRefType(baseptr, getTypeConverter());
-    auto arrT = emitc::ArrayType::get(memrefType.getShape(), emitcType);
-    auto valVar = rewriter.create<emitc::VariableOp>(
-        loc, arrT, emitc::OpaqueAttr::get(rewriter.getContext(), ""));
+    // The adaptor operand is the type-converted source, expected to be an
+    // !emitc.array. Report a match failure instead of asserting in `cast`
+    // when the type converter produced some other type.
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+    if (!srcArrayValue)
+      return rewriter.notifyMatchFailure(
+          loc, "expected converted source to be an emitc.array");
+
+    // Build a pointer to the first element, i.e. &src[0][0]...[0]: subscript
+    // every dimension with one shared zero index, then take the address of
+    // the resulting lvalue.
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    SmallVector<Value> indices(srcArrayValue.getType().getRank(),
+                               zeroIndex.getResult());
+    emitc::SubscriptOp subscript = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArrayValue, mlir::ValueRange(indices));
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc,
+        emitc::PointerType::get(srcArrayValue.getType().getElementType()),
+        rewriter.getStringAttr("&"), subscript);
+
     auto [strides, offset] = memrefType.getStridesAndOffset();
     Value offsetValue = rewriter.create<emitc::ConstantOp>(
         loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset));

     SmallVector<Value> results;
-    results.push_back(valVar);
+    results.push_back(srcPtr);
     results.push_back(offsetValue);

     for (unsigned i = 0, e = memrefType.getRank(); i < e; ++i) {
@@ -343,39 +354,6 @@ struct ConvertExtractStridedMetadata final
   }
 };
 
-struct ConvertReinterpretCastOp
-    : public OpConversionPattern<memref::ReinterpretCastOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    MemRefType srcType = cast<MemRefType>(castOp.getSource().getType());
-
-    MemRefType targetMemRefType =
-        cast<MemRefType>(castOp.getResult().getType());
-
-    auto srcInEmitC = convertMemRefType(srcType, getTypeConverter());
-    auto targetInEmitC =
-        convertMemRefType(targetMemRefType, getTypeConverter());
-    if (!srcInEmitC || !targetInEmitC) {
-      return rewriter.notifyMatchFailure(castOp.getLoc(),
-                                         "cannot convert memref type");
-    }
-
-    // Create descriptor.
-    Location loc = castOp.getLoc();
-
-    auto vals = adaptor.getOperands();
-
-    auto res =
-        UnrealizedConversionCastOp::create(rewriter, loc, targetInEmitC, vals)
-            .getResult(0);
-
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
@@ -409,7 +387,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
   patterns.add<ConvertAlloca, ConvertAlloc, ConvertExtractStridedMetadata,
-               ConvertGlobal, ConvertGetGlobal, ConvertLoad,
-               ConvertReinterpretCastOp, ConvertStore>(converter,
-                                                       patterns.getContext());
+               ConvertGlobal, ConvertGetGlobal, ConvertLoad, ConvertStore>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..d36eaf3c2673a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,19 @@ module @globals {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: extract_strided_metadata
+func.func @extract_strided_metadata(%arg0: memref<1xi32>) {
+  // CHECK: %[[ARR:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<1xi32> to !emitc.array<1xi32>
+  // CHECK: %[[ZERO:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: %[[ELEM:.*]] = emitc.subscript %[[ARR]][%[[ZERO]]] : (!emitc.array<1xi32>, index) -> !emitc.lvalue<i32>
+  // CHECK: emitc.apply "&"(%[[ELEM]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
+  // Static offset (0), size (1) and stride (1) become index constants.
+  // CHECK: "emitc.constant"() <{value = 0 : index}> : () -> index
+  // CHECK: "emitc.constant"() <{value = 1 : index}> : () -> index
+  // CHECK: "emitc.constant"() <{value = 1 : index}> : () -> index
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<1xi32> -> memref<i32>, index, index, index
+  return
+}

>From a8edc52cd07e10668bf73539a957fe55369f9cc8 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 5 Aug 2025 21:32:29 +0000
Subject: [PATCH 4/5] Restore the emitc.variable result constraint and verifier

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d19a32aa39734..04349d6bafb85 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,10 +1191,10 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
   }];
 
   let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
-  let results = (outs Res<AnyTypeOf<[EmitCType]>,
+  let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>,
                           "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
 
-  // let hasVerifier = 1;
+  let hasVerifier = 1;
 }
 
 def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 87d6f713ea35a..4c0902293cbf9 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -889,9 +889,9 @@ LogicalResult SubOp::verify() {
 // VariableOp
 //===----------------------------------------------------------------------===//
 
-// LogicalResult emitc::VariableOp::verify() {
-//   return verifyInitializationAttribute(getOperation(), getValueAttr());
-// }
+LogicalResult emitc::VariableOp::verify() {
+  return verifyInitializationAttribute(getOperation(), getValueAttr());
+}
 
 //===----------------------------------------------------------------------===//
 // YieldOp

>From ac631e8629cecfb5084e0957aa770afe5267948e Mon Sep 17 00:00:00 2001
From: Jaden Angella <ajaden at google.com>
Date: Tue, 5 Aug 2025 14:33:52 -0700
Subject: [PATCH 5/5] Restore original formatting of EmitC_VariableOp results

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 04349d6bafb85..7fe2da8f7e044 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1191,8 +1191,8 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
   }];
 
   let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
-  let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>,
-                          "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
+  let results = (outs Res<AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>, "",
+                          [MemAlloc<DefaultResource, 0, FullEffect>]>);
 
   let hasVerifier = 1;
 }



More information about the Mlir-commits mailing list