[Mlir-commits] [mlir] 0aa6d57 - [MLIR] Add initial convert-memref-to-emitc pass (#85389)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 21 06:27:41 PDT 2024


Author: Matthias Gehre
Date: 2024-03-21T14:27:37+01:00
New Revision: 0aa6d57e575dd920db81bef7ff509c4d3a9c6891

URL: https://github.com/llvm/llvm-project/commit/0aa6d57e575dd920db81bef7ff509c4d3a9c6891
DIFF: https://github.com/llvm/llvm-project/commit/0aa6d57e575dd920db81bef7ff509c4d3a9c6891.diff

LOG: [MLIR] Add initial convert-memref-to-emitc pass (#85389)

This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.

Added: 
    mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
    mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
    mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
    mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
    mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
    mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
    mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..734ffdba520c9f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,21 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                             TypeConverter &converter);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

diff  --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
new file mode 100644
index 00000000000000..4a63014c19ad0e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
@@ -0,0 +1,20 @@
+//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..2179ae18ac074b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+  let summary = "Convert MemRef dialect to EmitC dialect";
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(NVGPUToNVVM)

diff  --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..8a72e747d024ae
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRMemRefToEmitC
+  MemRefToEmitC.cpp
+  MemRefToEmitCPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIREmitCDialect
+  MLIRMemRefDialect
+  MLIRTransforms
+  )

diff  --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..0e3b6469212640
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,114 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with dynamic shape");
+    }
+
+    if (op.getAlignment().value_or(1) > 1) {
+      // TODO: Allow alignment if it is not more than the natural alignment
+      // of the C array.
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with alignment requirement");
+    }
+
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+    }
+    auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+    rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
+    return success();
+  }
+};
+
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+    }
+
+    auto subscript = rewriter.create<emitc::SubscriptOp>(
+        op.getLoc(), operands.getMemref(), operands.getIndices());
+
+    auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+    auto var =
+        rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
+
+    rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
+    rewriter.replaceOp(op, var);
+    return success();
+  }
+};
+
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto subscript = rewriter.create<emitc::SubscriptOp>(
+        op.getLoc(), operands.getMemref(), operands.getIndices());
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                 operands.getValue());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&](MemRefType memRefType) -> std::optional<Type> {
+        if (!memRefType.hasStaticShape() ||
+            !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
+          return {};
+        }
+        Type convertedElementType =
+            typeConverter.convertType(memRefType.getElementType());
+        if (!convertedElementType)
+          return {};
+        return emitc::ArrayType::get(memRefType.getShape(),
+                                     convertedElementType);
+      });
+}
+
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                                   TypeConverter &converter) {
+  patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
+                                                         patterns.getContext());
+}

diff  --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
new file mode 100644
index 00000000000000..4e5d1912d15729
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -0,0 +1,55 @@
+//===- MemRefToEmitCPass.cpp - MemRef to EmitC conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ConvertMemRefToEmitCPass
+    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+  void runOnOperation() override {
+    TypeConverter converter;
+
+    // Fallback for other types.
+    converter.addConversion([](Type type) -> std::optional<Type> {
+      if (isa<MemRefType>(type))
+        return {};
+      return type;
+    });
+
+    populateMemRefToEmitCTypeConversion(converter);
+
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToEmitCConversionPatterns(patterns, converter);
+
+    ConversionTarget target(getContext());
+    target.addIllegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<emitc::EmitCDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace

diff  --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
new file mode 100644
index 00000000000000..390190d341e5ae
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+  // expected-error@+1 {{failed to legalize operation 'memref.copy'}}
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_dynamic_shape() {
+  %0 = index.constant 1
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %1 = memref.alloca(%0) : memref<4x?xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_alignment() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+  return
+}
+
+// -----
+
+func.func @non_identity_layout() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+  return
+}
+
+// -----
+
+func.func @zero_rank() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() : memref<f32>
+  return
+}

diff  --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..9793b2d6d7832f
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_store
+// CHECK-SAME:  %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_store(%v : f32, %i: index, %j: index) {
+  // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT]] : f32
+  memref.store %v, %0[%i, %j] : memref<4x8xf32>
+  return
+}
+// -----
+
+// CHECK-LABEL: memref_load
+// CHECK-SAME:  %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load(%i: index, %j: index) -> f32 {
+  // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+
+  // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
+  %1 = memref.load %0[%i, %j] : memref<4x8xf32>
+  // CHECK: return %[[VAR]] : f32
+  return %1 : f32
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 201c7f653398ae..3b575d4a413c31 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4186,6 +4186,7 @@ cc_library(
         ":MathToLLVM",
         ":MathToLibm",
         ":MathToSPIRV",
+        ":MemRefToEmitC",
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
         ":NVGPUToNVVM",
@@ -8256,6 +8257,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "MemRefToEmitC",
+    srcs = glob([
+        "lib/Conversion/MemRefToEmitC/*.cpp",
+        "lib/Conversion/MemRefToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/MemRefToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/MemRefToEmitC",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":MemRefDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "MemRefToLLVM",
     srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),


        


More information about the Mlir-commits mailing list