[llvm] [mlir] [MLIR] Add initial convert-memref-to-emitc pass (PR #85389)

Matthias Gehre via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 08:06:41 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/85389

>From 88cdc80038074b003269909868ef4072b1df88a5 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 15 Mar 2024 13:27:01 +0100
Subject: [PATCH 1/3] [MLIR] Add initial convert-memref-to-emitc pass

This translates memref types in func.func, func.call and func.return
to emitc.array, and it translates memref.alloca, memref.load and
memref.store to emitc.variable, emitc.subscript and emitc.assign.
---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |  29 ++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |   9 +
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../Conversion/MemRefToEmitC/CMakeLists.txt   |  19 ++
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 163 ++++++++++++++++++
 .../MemRefToEmitC/memref-to-emit-failed.mlir  |  40 +++++
 .../MemRefToEmitC/memref-to-emitc.mlir        |  47 +++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  29 ++++
 9 files changed, 338 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
 create mode 100644 mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
 create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..9815267a2f28b8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,29 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                             TypeConverter &converter);
+
+std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..8dde20099e7aa3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,22 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+  let summary = "Convert MemRef dialect to EmitC dialect";
+  let description = [{
+    Converts `memref.alloca`, `memref.load` and `memref.store` to
+    `emitc.variable`, `emitc.subscript` and `emitc.assign`, and converts memref
+    types used by `func.func`, `func.call` and `func.return` to `emitc.array`.
+    Only statically shaped memrefs are supported; memref and func operations
+    that still involve other memref types cause the conversion to fail.
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
 add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(NVGPUToNVVM)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..7bcec4cbadfce4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMemRefToEmitC
+  MemRefToEmitC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIREmitCDialect
+  MLIRFuncDialect
+  MLIRFuncTransforms
+  MLIRMemRefDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..b40bf8e9bf0df8
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,163 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Disallow all memrefs even though we only have conversions
+/// for memrefs with static shape right now to have good diagnostics.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+  return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+bool isLegal(Operation *op) {
+  return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+bool isSignatureLegal(FunctionType ty) {
+  return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with dynamic shape");
+
+    // TODO: Allow alignment if it is not more than the natural alignment
+    // of the C array.
+    if (op.getAlignment().value_or(1) > 1)
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform alloca with alignment requirement");
+
+    // Bail out if the converter failed or kept the memref type unchanged.
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy || isa<BaseMemRefType>(resultTy))
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+    auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+    rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
+    return success();
+  }
+};
+
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO(review): the result aliases the element; copy it to a variable?
+    rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
+                                                    operands.getIndices());
+    return success();
+  }
+};
+
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // memref.store becomes an assignment to the subscripted element.
+    auto subscript = rewriter.create<emitc::SubscriptOp>(
+        op.getLoc(), operands.getMemref(), operands.getIndices());
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+                                                 operands.getValue());
+    return success();
+  }
+};
+
+struct ConvertMemRefToEmitCPass
+    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+  void runOnOperation() override {
+    TypeConverter converter;
+    // Fallback for other types.
+    converter.addConversion([](Type type) { return type; });
+    populateMemRefToEmitCTypeConversion(converter);
+    converter.addConversion(
+        [&converter](FunctionType ty) -> std::optional<Type> {
+          SmallVector<Type> inputs;
+          if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+            return std::nullopt;
+
+          SmallVector<Type> results;
+          if (failed(converter.convertTypes(ty.getResults(), results)))
+            return std::nullopt;
+
+          return FunctionType::get(ty.getContext(), inputs, results);
+        });
+
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToEmitCConversionPatterns(patterns, converter);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+    target.addDynamicallyLegalDialect<func::FuncDialect>(
+        [](Operation *op) { return isLegal(op); });
+    target.addIllegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<emitc::EmitCDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+  typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
+    if (memRefType.hasStaticShape()) {
+      return emitc::ArrayType::get(memRefType.getShape(),
+                                   memRefType.getElementType());
+    }
+    return {}; // std::nullopt: no conversion provided for this memref here.
+  });
+}
+
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                                   TypeConverter &converter) {
+  patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
+                                                         patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
+  return std::make_unique<ConvertMemRefToEmitCPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
new file mode 100644
index 00000000000000..0d4e4139b85fb4
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+// Unranked memrefs are not converted
+// expected-error at +1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_unranked(%arg0 : memref<*xf32>) {
+  return
+}
+
+// -----
+
+// Memrefs with dynamic shapes are not converted
+// expected-error at +1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
+  return
+}
+
+// -----
+
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+  // expected-error at +1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_dynamic_shape() {
+  %0 = index.constant 1
+  // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+  %1 = memref.alloca(%0) : memref<4x?xf32>
+  return
+}
+
+// -----
+
+func.func @alloca_with_alignment() {
+  // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+  %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+  return
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..8a11dcf7603c62
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_arg
+// CHECK-SAME:  !emitc.array<32xf32>)
+func.func @memref_arg(%arg0 : memref<32xf32>) {
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: memref_return
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
+func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
+// CHECK: return %[[arg0]] : !emitc.array<32xf32>
+  func.return %arg0 : memref<32xf32>
+}
+
+// CHECK-LABEL: memref_call
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<32xf32>)
+func.func @memref_call(%arg0 : memref<32xf32>) {
+// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
+  func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: alloca
+func.func @alloca() {
+  // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: memref_load_store
+// CHECK-SAME:  %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
+// CHECK-SAME:  %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
+  // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
+  %0 = memref.load %in[%i, %j] : memref<4x8xf32>
+  // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
+  // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc]] : f32
+  memref.store %0, %out[%i, %j] : memref<3x5xf32>
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a3243abde5bbc5..2582fd78cd6c81 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4168,6 +4168,7 @@ cc_library(
         ":MathToLLVM",
         ":MathToLibm",
         ":MathToSPIRV",
+        ":MemRefToEmitC",
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
         ":NVGPUToNVVM",
@@ -8180,6 +8181,34 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "MemRefToEmitC",
+    srcs = glob([
+        "lib/Conversion/MemRefToEmitC/*.cpp",
+        "lib/Conversion/MemRefToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/MemRefToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/MemRefToEmitC",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":FuncDialect",
+        ":FuncTransforms",
+        ":IR",
+        ":MemRefDialect",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "MemRefToLLVM",
     srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),

>From 62ad8c029b3f82f892416c616e926471491e6ca2 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 15 Mar 2024 15:09:01 +0100
Subject: [PATCH 2/3] Remove redundant pass creator

---
 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h | 3 ---
 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp        | 4 ----
 2 files changed, 7 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 9815267a2f28b8..762c8e67980ba4 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -21,9 +21,6 @@ void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
 
 void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
                                              TypeConverter &converter);
-
-std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index b40bf8e9bf0df8..d2d2e85d1984aa 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -157,7 +157,3 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
   patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
                                                          patterns.getContext());
 }
-
-std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
-  return std::make_unique<ConvertMemRefToEmitCPass>();
-}

>From 551fddeaaaf1671724adf8becda4d1b5e2e8a81d Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 18 Mar 2024 16:05:37 +0100
Subject: [PATCH 3/3] check for identity layout; split pass to separate source

---
 .../Conversion/MemRefToEmitC/MemRefToEmitC.h  |  5 -
 .../MemRefToEmitC/MemRefToEmitCPass.h         | 20 ++++
 mlir/include/mlir/Conversion/Passes.h         |  2 +-
 .../Conversion/MemRefToEmitC/CMakeLists.txt   |  1 +
 .../MemRefToEmitC/MemRefToEmitC.cpp           | 70 +-------------
 .../MemRefToEmitC/MemRefToEmitCPass.cpp       | 91 +++++++++++++++++++
 .../MemRefToEmitC/memref-to-emit-failed.mlir  |  8 ++
 7 files changed, 123 insertions(+), 74 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
 create mode 100644 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp

diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 762c8e67980ba4..734ffdba520c9f 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -8,15 +8,10 @@
 #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
 #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
 
-#include "mlir/Pass/Pass.h"
-
 namespace mlir {
 class RewritePatternSet;
 class TypeConverter;
 
-#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-
 void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
 
 void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
new file mode 100644
index 00000000000000..4a63014c19ad0e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
@@ -0,0 +1,20 @@
+//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8dde20099e7aa3..2179ae18ac074b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,7 +45,7 @@
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
-#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
index 7bcec4cbadfce4..372c5b89d83008 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRMemRefToEmitC
   MemRefToEmitC.cpp
+  MemRefToEmitCPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index d2d2e85d1984aa..0d8a7304354fe8 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -6,46 +6,21 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a pass to convert memref ops into emitc ops.
+// This file implements patterns to convert memref ops into emitc ops.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
 using namespace mlir;
 
 namespace {
-
-/// Disallow all memrefs even though we only have conversions
-/// for memrefs with static shape right now to have good diagnostics.
-bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
-
-template <typename RangeT>
-bool areLegal(RangeT &&range) {
-  return llvm::all_of(range, [](Type type) { return isLegal(type); });
-}
-
-bool isLegal(Operation *op) {
-  return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
-}
-
-bool isSignatureLegal(FunctionType ty) {
-  return areLegal(ty.getInputs()) && areLegal(ty.getResults());
-}
-
 struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -99,52 +74,13 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
     return success();
   }
 };
-
-struct ConvertMemRefToEmitCPass
-    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
-  void runOnOperation() override {
-    TypeConverter converter;
-    // Fallback for other types.
-    converter.addConversion([](Type type) { return type; });
-    populateMemRefToEmitCTypeConversion(converter);
-    converter.addConversion(
-        [&converter](FunctionType ty) -> std::optional<Type> {
-          SmallVector<Type> inputs;
-          if (failed(converter.convertTypes(ty.getInputs(), inputs)))
-            return std::nullopt;
-
-          SmallVector<Type> results;
-          if (failed(converter.convertTypes(ty.getResults(), results)))
-            return std::nullopt;
-
-          return FunctionType::get(ty.getContext(), inputs, results);
-        });
-
-    RewritePatternSet patterns(&getContext());
-    populateMemRefToEmitCConversionPatterns(patterns, converter);
-    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
-                                                                   converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
-    populateReturnOpTypeConversionPattern(patterns, converter);
-
-    ConversionTarget target(getContext());
-    target.addDynamicallyLegalOp<func::FuncOp>(
-        [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
-    target.addDynamicallyLegalDialect<func::FuncDialect>(
-        [](Operation *op) { return isLegal(op); });
-    target.addIllegalDialect<memref::MemRefDialect>();
-    target.addLegalDialect<emitc::EmitCDialect>();
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
 } // namespace
 
 void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
   typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
-    if (memRefType.hasStaticShape()) {
+    // Rank-0 memrefs are skipped: `emitc.array` requires a non-empty shape.
+    if (memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+        memRefType.getRank() > 0) {
       return emitc::ArrayType::get(memRefType.getShape(),
                                    memRefType.getElementType());
     }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
new file mode 100644
index 00000000000000..4233a98fbf8ea4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -0,0 +1,98 @@
+//===- MemRefToEmitCPass.cpp - MemRef to EmitC conversion pass ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Disallow all memrefs, even though only memrefs with a static shape and an
+/// identity layout can be converted right now, to get good diagnostics.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+/// Returns true if no type in `range` is a memref type.
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+  return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+/// Returns true if no operand or result type of `op` is a memref type.
+bool isLegal(Operation *op) {
+  return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+/// Returns true if no input or result type of `ty` is a memref type.
+bool isSignatureLegal(FunctionType ty) {
+  return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+/// Lowers memref ops and memref-typed func.func/func.call/func.return to EmitC.
+/// Any memref op, or func op still using a memref type, makes the pass fail.
+struct ConvertMemRefToEmitCPass
+    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+  void runOnOperation() override {
+    TypeConverter converter;
+    // Fallback for other types: keep them unchanged.
+    // NOTE(review): memrefs rejected by the memref conversion presumably fall
+    // back to this identity mapping as well - confirm before relying on it.
+    converter.addConversion([](Type type) { return type; });
+    populateMemRefToEmitCTypeConversion(converter);
+    converter.addConversion(
+        [&converter](FunctionType ty) -> std::optional<Type> {
+          SmallVector<Type> inputs;
+          if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+            return std::nullopt;
+
+          SmallVector<Type> results;
+          if (failed(converter.convertTypes(ty.getResults(), results)))
+            return std::nullopt;
+
+          return FunctionType::get(ty.getContext(), inputs, results);
+        });
+
+    RewritePatternSet patterns(&getContext());
+    populateMemRefToEmitCConversionPatterns(patterns, converter);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+
+    ConversionTarget target(getContext());
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+    target.addDynamicallyLegalDialect<func::FuncDialect>(
+        [](Operation *op) { return isLegal(op); });
+    target.addIllegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<emitc::EmitCDialect>();
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
index 0d4e4139b85fb4..8d19cc452f9604 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -38,3 +38,11 @@ func.func @alloca_with_alignment() {
   %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
   return
 }
+
+// -----
+// NOTE(review): the alignment attribute alone already rejects this alloca.
+func.func @non_identity_layout() {
+  // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+  %1 = memref.alloca() {alignment = 64 : i64}: memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+  return
+}



More information about the llvm-commits mailing list