[Mlir-commits] [llvm] [mlir] [MLIR] Add initial convert-memref-to-emitc pass (PR #85389)
Matthias Gehre
llvmlistbot at llvm.org
Thu Mar 21 05:13:48 PDT 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/85389
>From 88cdc80038074b003269909868ef4072b1df88a5 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 15 Mar 2024 13:27:01 +0100
Subject: [PATCH 1/8] [MLIR] Add initial convert-memref-to-emitc pass
This translates memref types in func.func, func.call and func.return
to emitc.array
and it translates memref.alloca, memref.load & memref.store to
emitc.variable, emitc.subscript and emitc.assign.
---
.../Conversion/MemRefToEmitC/MemRefToEmitC.h | 29 ++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 9 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/MemRefToEmitC/CMakeLists.txt | 19 ++
.../MemRefToEmitC/MemRefToEmitC.cpp | 163 ++++++++++++++++++
.../MemRefToEmitC/memref-to-emit-failed.mlir | 40 +++++
.../MemRefToEmitC/memref-to-emitc.mlir | 47 +++++
.../llvm-project-overlay/mlir/BUILD.bazel | 29 ++++
9 files changed, 338 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
create mode 100644 mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..9815267a2f28b8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,29 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter);
+
+std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..8dde20099e7aa3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+ let summary = "Convert MemRef dialect to EmitC dialect";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(NVGPUToNVVM)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..7bcec4cbadfce4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMemRefToEmitC
+ MemRefToEmitC.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIREmitCDialect
+ MLIRFuncDialect
+ MLIRFuncTransforms
+ MLIRMemRefDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..b40bf8e9bf0df8
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,163 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Disallow all memrefs even though we only have conversions
+/// for memrefs with static shape right now to have good diagnostics.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+ return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+bool isLegal(Operation *op) {
+ return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+bool isSignatureLegal(FunctionType ty) {
+ return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!op.getType().hasStaticShape()) {
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "cannot transform alloca with dynamic shape");
+ }
+
+ if (op.getAlignment().value_or(1) > 1) {
+ // TODO: Allow alignment if it is not more than the natural alignment
+ // of the C array.
+ return rewriter.notifyMatchFailure(
+ op.getLoc(), "cannot transform alloca with alignment requirement");
+ }
+
+ auto resultTy = getTypeConverter()->convertType(op.getType());
+ auto noInit = emitc::OpaqueAttr::get(getContext(), "");
+ rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
+ return success();
+ }
+};
+
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
+ operands.getIndices());
+ return success();
+ }
+};
+
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto subscript = rewriter.create<emitc::SubscriptOp>(
+ op.getLoc(), operands.getMemref(), operands.getIndices());
+ rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+ operands.getValue());
+ return success();
+ }
+};
+
+struct ConvertMemRefToEmitCPass
+ : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ void runOnOperation() override {
+ TypeConverter converter;
+ // Fallback for other types.
+ converter.addConversion([](Type type) { return type; });
+ populateMemRefToEmitCTypeConversion(converter);
+ converter.addConversion(
+ [&converter](FunctionType ty) -> std::optional<Type> {
+ SmallVector<Type> inputs;
+ if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+ return std::nullopt;
+
+ SmallVector<Type> results;
+ if (failed(converter.convertTypes(ty.getResults(), results)))
+ return std::nullopt;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+
+ RewritePatternSet patterns(&getContext());
+ populateMemRefToEmitCConversionPatterns(patterns, converter);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+
+ ConversionTarget target(getContext());
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+ target.addDynamicallyLegalDialect<func::FuncDialect>(
+ [](Operation *op) { return isLegal(op); });
+ target.addIllegalDialect<memref::MemRefDialect>();
+ target.addLegalDialect<emitc::EmitCDialect>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
+
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+ typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
+ if (memRefType.hasStaticShape()) {
+ return emitc::ArrayType::get(memRefType.getShape(),
+ memRefType.getElementType());
+ }
+ return {};
+ });
+}
+
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter) {
+ patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
+ patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
+ return std::make_unique<ConvertMemRefToEmitCPass>();
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
new file mode 100644
index 00000000000000..0d4e4139b85fb4
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+// Unranked memrefs are not converted
+// expected-error at +1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_unranked(%arg0 : memref<*xf32>) {
+ return
+}
+
+// -----
+
+// Memrefs with dynamic shapes are not converted
+// expected-error at +1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
+ return
+}
+
+// -----
+
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+ // expected-error at +1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+ memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+ return
+}
+
+// -----
+
+func.func @alloca_with_dynamic_shape() {
+ %0 = index.constant 1
+ // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ %1 = memref.alloca(%0) : memref<4x?xf32>
+ return
+}
+
+// -----
+
+func.func @alloca_with_alignment() {
+ // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+ return
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..8a11dcf7603c62
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_arg
+// CHECK-SAME: !emitc.array<32xf32>)
+func.func @memref_arg(%arg0 : memref<32xf32>) {
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: memref_return
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
+func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
+// CHECK: return %[[arg0]] : !emitc.array<32xf32>
+ func.return %arg0 : memref<32xf32>
+}
+
+// CHECK-LABEL: memref_call
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
+func.func @memref_call(%arg0 : memref<32xf32>) {
+// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
+ func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: alloca
+func.func @alloca() {
+ // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+ %0 = memref.alloca() : memref<4x8xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: memref_load_store
+// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
+// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
+ // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
+ %0 = memref.load %in[%i, %j] : memref<4x8xf32>
+ // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
+ // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
+ memref.store %0, %out[%i, %j] : memref<3x5xf32>
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a3243abde5bbc5..2582fd78cd6c81 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4168,6 +4168,7 @@ cc_library(
":MathToLLVM",
":MathToLibm",
":MathToSPIRV",
+ ":MemRefToEmitC",
":MemRefToLLVM",
":MemRefToSPIRV",
":NVGPUToNVVM",
@@ -8180,6 +8181,34 @@ cc_library(
],
)
+cc_library(
+ name = "MemRefToEmitC",
+ srcs = glob([
+ "lib/Conversion/MemRefToEmitC/*.cpp",
+ "lib/Conversion/MemRefToEmitC/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Conversion/MemRefToEmitC/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/MemRefToEmitC",
+ ],
+ deps = [
+ ":ConversionPassIncGen",
+ ":EmitCDialect",
+ ":FuncDialect",
+ ":FuncTransforms",
+ ":MemRefDialect",
+ ":IR",
+ ":Pass",
+ ":Support",
+ ":TransformUtils",
+ ":Transforms",
+ "//llvm:Support",
+ ],
+)
+
cc_library(
name = "MemRefToLLVM",
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),
>From 62ad8c029b3f82f892416c616e926471491e6ca2 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Fri, 15 Mar 2024 15:09:01 +0100
Subject: [PATCH 2/8] Remove redundant pass creator
---
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h | 3 ---
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 4 ----
2 files changed, 7 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 9815267a2f28b8..762c8e67980ba4 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -21,9 +21,6 @@ void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter);
-
-std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
-
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index b40bf8e9bf0df8..d2d2e85d1984aa 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -157,7 +157,3 @@ void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
}
-
-std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
- return std::make_unique<ConvertMemRefToEmitCPass>();
-}
>From 551fddeaaaf1671724adf8becda4d1b5e2e8a81d Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 18 Mar 2024 16:05:37 +0100
Subject: [PATCH 3/8] check for identity layout; split pass to separate source
---
.../Conversion/MemRefToEmitC/MemRefToEmitC.h | 5 -
.../MemRefToEmitC/MemRefToEmitCPass.h | 20 ++++
mlir/include/mlir/Conversion/Passes.h | 2 +-
.../Conversion/MemRefToEmitC/CMakeLists.txt | 1 +
.../MemRefToEmitC/MemRefToEmitC.cpp | 70 +-------------
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 91 +++++++++++++++++++
.../MemRefToEmitC/memref-to-emit-failed.mlir | 8 ++
7 files changed, 123 insertions(+), 74 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
create mode 100644 mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 762c8e67980ba4..734ffdba520c9f 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -8,15 +8,10 @@
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
-#include "mlir/Pass/Pass.h"
-
namespace mlir {
class RewritePatternSet;
class TypeConverter;
-#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
new file mode 100644
index 00000000000000..4a63014c19ad0e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
@@ -0,0 +1,20 @@
+//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8dde20099e7aa3..2179ae18ac074b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,7 +45,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
-#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
index 7bcec4cbadfce4..372c5b89d83008 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRMemRefToEmitC
MemRefToEmitC.cpp
+ MemRefToEmitCPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index d2d2e85d1984aa..0d8a7304354fe8 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -6,46 +6,21 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a pass to convert memref ops into emitc ops.
+// This file implements patterns to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
using namespace mlir;
namespace {
-
-/// Disallow all memrefs even though we only have conversions
-/// for memrefs with static shape right now to have good diagnostics.
-bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
-
-template <typename RangeT>
-bool areLegal(RangeT &&range) {
- return llvm::all_of(range, [](Type type) { return isLegal(type); });
-}
-
-bool isLegal(Operation *op) {
- return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
-}
-
-bool isSignatureLegal(FunctionType ty) {
- return areLegal(ty.getInputs()) && areLegal(ty.getResults());
-}
-
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
using OpConversionPattern::OpConversionPattern;
@@ -99,52 +74,11 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};
-
-struct ConvertMemRefToEmitCPass
- : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
- void runOnOperation() override {
- TypeConverter converter;
- // Fallback for other types.
- converter.addConversion([](Type type) { return type; });
- populateMemRefToEmitCTypeConversion(converter);
- converter.addConversion(
- [&converter](FunctionType ty) -> std::optional<Type> {
- SmallVector<Type> inputs;
- if (failed(converter.convertTypes(ty.getInputs(), inputs)))
- return std::nullopt;
-
- SmallVector<Type> results;
- if (failed(converter.convertTypes(ty.getResults(), results)))
- return std::nullopt;
-
- return FunctionType::get(ty.getContext(), inputs, results);
- });
-
- RewritePatternSet patterns(&getContext());
- populateMemRefToEmitCConversionPatterns(patterns, converter);
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
- converter);
- populateCallOpTypeConversionPattern(patterns, converter);
- populateReturnOpTypeConversionPattern(patterns, converter);
-
- ConversionTarget target(getContext());
- target.addDynamicallyLegalOp<func::FuncOp>(
- [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
- target.addDynamicallyLegalDialect<func::FuncDialect>(
- [](Operation *op) { return isLegal(op); });
- target.addIllegalDialect<memref::MemRefDialect>();
- target.addLegalDialect<emitc::EmitCDialect>();
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
-};
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
- if (memRefType.hasStaticShape()) {
+ if (memRefType.hasStaticShape() && memRefType.getLayout().isIdentity()) {
return emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
new file mode 100644
index 00000000000000..4233a98fbf8ea4
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -0,0 +1,91 @@
+//===- MemRefToEmitCPass.cpp - MemRef to EmitC conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Mark every memref type as illegal, even though only memrefs with static
+/// shape can be converted right now, so that the others are diagnosed.
+bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
+
+template <typename RangeT>
+bool areLegal(RangeT &&range) {
+ return llvm::all_of(range, [](Type type) { return isLegal(type); });
+}
+
+bool isLegal(Operation *op) {
+ return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
+}
+
+bool isSignatureLegal(FunctionType ty) {
+ return areLegal(ty.getInputs()) && areLegal(ty.getResults());
+}
+
+struct ConvertMemRefToEmitCPass
+ : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ void runOnOperation() override {
+ TypeConverter converter;
+ // Fallback for other types.
+ converter.addConversion([](Type type) { return type; });
+ populateMemRefToEmitCTypeConversion(converter);
+ converter.addConversion(
+ [&converter](FunctionType ty) -> std::optional<Type> {
+ SmallVector<Type> inputs;
+ if (failed(converter.convertTypes(ty.getInputs(), inputs)))
+ return std::nullopt;
+
+ SmallVector<Type> results;
+ if (failed(converter.convertTypes(ty.getResults(), results)))
+ return std::nullopt;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+
+ RewritePatternSet patterns(&getContext());
+ populateMemRefToEmitCConversionPatterns(patterns, converter);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+
+ ConversionTarget target(getContext());
+ target.addDynamicallyLegalOp<func::FuncOp>(
+ [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
+ target.addDynamicallyLegalDialect<func::FuncDialect>(
+ [](Operation *op) { return isLegal(op); });
+ target.addIllegalDialect<memref::MemRefDialect>();
+ target.addLegalDialect<emitc::EmitCDialect>();
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
index 0d4e4139b85fb4..8d19cc452f9604 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -38,3 +38,11 @@ func.func @alloca_with_alignment() {
%1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
return
}
+
+// -----
+
+func.func @non_identity_layout() {
+ // expected-error at +1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ %1 = memref.alloca() {alignment = 64 : i64}: memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+ return
+}
>From 152eee3a2908c9fc7445416cf14eeba77d32ee68 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 18 Mar 2024 16:11:40 +0100
Subject: [PATCH 4/8] Change order of converter & patterns args to align with
rest of MLIR
---
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h | 4 ++--
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 734ffdba520c9f..3a4110a7b80687 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -14,8 +14,8 @@ class TypeConverter;
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
-void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
- TypeConverter &converter);
+void populateMemRefToEmitCConversionPatterns(TypeConverter &converter,
+ RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0d8a7304354fe8..670fe64fcff681 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -86,8 +86,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
});
}
-void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
- TypeConverter &converter) {
+void mlir::populateMemRefToEmitCConversionPatterns(
+ TypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
}
>From 170765f9a56cecfac7ae293362e9d178ae7406e2 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 19 Mar 2024 23:19:40 +0100
Subject: [PATCH 5/8] ensure that rank is greater than zero
---
.../mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h | 4 ++--
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 7 ++++---
.../MemRefToEmitC/memref-to-emit-failed.mlir | 12 +++++++++---
3 files changed, 15 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index 3a4110a7b80687..734ffdba520c9f 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -14,8 +14,8 @@ class TypeConverter;
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
-void populateMemRefToEmitCConversionPatterns(TypeConverter &converter,
- RewritePatternSet &patterns);
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter);
} // namespace mlir
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 670fe64fcff681..e77c15636318b0 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -78,7 +78,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
- if (memRefType.hasStaticShape() && memRefType.getLayout().isIdentity()) {
+ if (memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() > 0) {
return emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
}
@@ -86,8 +87,8 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
});
}
-void mlir::populateMemRefToEmitCConversionPatterns(
- TypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
index 8d19cc452f9604..4c728299186853 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -41,8 +41,14 @@ func.func @alloca_with_alignment() {
// -----
-func.func @non_identity_layout() {
- // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
- %1 = memref.alloca() {alignment = 64 : i64}: memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @non_identity_layout(%arg0 : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>) {
+ return
+}
+
+// -----
+
+// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @zero_rank(%arg0 : memref<f32>) {
return
}
>From 3fba2faaedb5a608ae3afa31b0bd51fd71f89dbb Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Wed, 20 Mar 2024 16:02:13 +0100
Subject: [PATCH 6/8] Perform type conversion recursively
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 21 ++++++++++++-------
1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e77c15636318b0..c7215716449aab 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,14 +77,19 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
} // namespace
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
- typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
- if (memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
- memRefType.getRank() > 0) {
- return emitc::ArrayType::get(memRefType.getShape(),
- memRefType.getElementType());
- }
- return {};
- });
+ typeConverter.addConversion(
+ [&](MemRefType memRefType) -> std::optional<Type> {
+ if (!memRefType.hasStaticShape() ||
+ !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
+ return {};
+ }
+ Type convertedElementType =
+ typeConverter.convertType(memRefType.getElementType());
+ if (!convertedElementType)
+ return {};
+ return emitc::ArrayType::get(memRefType.getShape(),
+ convertedElementType);
+ });
}
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
>From d27fc8f68b7b3e3bab0ad5da7a40a6061208d8b8 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 21 Mar 2024 11:54:04 +0100
Subject: [PATCH 7/8] Remove conversion of func ops
---
.../MemRefToEmitC/MemRefToEmitC.cpp | 3 ++
.../MemRefToEmitC/MemRefToEmitCPass.cpp | 50 +++----------------
.../MemRefToEmitC/memref-to-emit-failed.mlir | 34 ++++---------
.../MemRefToEmitC/memref-to-emitc.mlir | 49 ++++--------------
4 files changed, 29 insertions(+), 107 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index c7215716449aab..a82312f5c02147 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -41,6 +41,9 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
auto resultTy = getTypeConverter()->convertType(op.getType());
+ if (!resultTy) {
+ return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+ }
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
return success();
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index 4233a98fbf8ea4..4e5d1912d15729 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -14,11 +14,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -30,56 +26,24 @@ namespace mlir {
using namespace mlir;
namespace {
-
-/// Disallow all memrefs even though we only have conversions
-/// for memrefs with static shape right now to have good diagnostics.
-bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
-
-template <typename RangeT>
-bool areLegal(RangeT &&range) {
- return llvm::all_of(range, [](Type type) { return isLegal(type); });
-}
-
-bool isLegal(Operation *op) {
- return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
-}
-
-bool isSignatureLegal(FunctionType ty) {
- return areLegal(ty.getInputs()) && areLegal(ty.getResults());
-}
-
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
TypeConverter converter;
- // Fallback for other types.
- converter.addConversion([](Type type) { return type; });
- populateMemRefToEmitCTypeConversion(converter);
- converter.addConversion(
- [&converter](FunctionType ty) -> std::optional<Type> {
- SmallVector<Type> inputs;
- if (failed(converter.convertTypes(ty.getInputs(), inputs)))
- return std::nullopt;
- SmallVector<Type> results;
- if (failed(converter.convertTypes(ty.getResults(), results)))
- return std::nullopt;
+ // Fallback for other types.
+ converter.addConversion([](Type type) -> std::optional<Type> {
+ if (isa<MemRefType>(type))
+ return {};
+ return type;
+ });
- return FunctionType::get(ty.getContext(), inputs, results);
- });
+ populateMemRefToEmitCTypeConversion(converter);
RewritePatternSet patterns(&getContext());
populateMemRefToEmitCConversionPatterns(patterns, converter);
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
- converter);
- populateCallOpTypeConversionPattern(patterns, converter);
- populateReturnOpTypeConversionPattern(patterns, converter);
ConversionTarget target(getContext());
- target.addDynamicallyLegalOp<func::FuncOp>(
- [](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
- target.addDynamicallyLegalDialect<func::FuncDialect>(
- [](Operation *op) { return isLegal(op); });
target.addIllegalDialect<memref::MemRefDialect>();
target.addLegalDialect<emitc::EmitCDialect>();
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
index 4c728299186853..390190d341e5ae 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emit-failed.mlir
@@ -1,23 +1,7 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
-// Unranked memrefs are not converted
-// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
-func.func @memref_unranked(%arg0 : memref<*xf32>) {
- return
-}
-
-// -----
-
-// Memrefs with dynamic shapes are not converted
-// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
-func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
- return
-}
-
-// -----
-
func.func @memref_op(%arg0 : memref<2x4xf32>) {
- // expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+ // expected-error@+1 {{failed to legalize operation 'memref.copy'}}
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
return
}
@@ -26,7 +10,7 @@ func.func @memref_op(%arg0 : memref<2x4xf32>) {
func.func @alloca_with_dynamic_shape() {
%0 = index.constant 1
- // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%1 = memref.alloca(%0) : memref<4x?xf32>
return
}
@@ -34,21 +18,23 @@ func.func @alloca_with_dynamic_shape() {
// -----
func.func @alloca_with_alignment() {
- // expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
- %1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+ %0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
return
}
// -----
-// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
-func.func @non_identity_layout(%arg0 : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>) {
+func.func @non_identity_layout() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+ %0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
return
}
// -----
-// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
-func.func @zero_rank(%arg0 : memref<f32>) {
+func.func @zero_rank() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+ %0 = memref.alloca() : memref<f32>
return
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 8a11dcf7603c62..ba79703adebff5 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -1,47 +1,16 @@
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
-// CHECK-LABEL: memref_arg
-// CHECK-SAME: !emitc.array<32xf32>)
-func.func @memref_arg(%arg0 : memref<32xf32>) {
- func.return
-}
-
-// -----
-
-// CHECK-LABEL: memref_return
-// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
-func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
-// CHECK: return %[[arg0]] : !emitc.array<32xf32>
- func.return %arg0 : memref<32xf32>
-}
-
-// CHECK-LABEL: memref_call
-// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
-func.func @memref_call(%arg0 : memref<32xf32>) {
-// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
- func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
- func.return
-}
-
-// -----
-
-// CHECK-LABEL: alloca
-func.func @alloca() {
- // CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+// CHECK-LABEL: memref
+// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref(%i: index, %j: index) {
+ // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- return
-}
-// -----
+ // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ %1 = memref.load %0[%i, %j] : memref<4x8xf32>
-// CHECK-LABEL: memref_load_store
-// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
-// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
-func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
- // CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
- %0 = memref.load %in[%i, %j] : memref<4x8xf32>
- // CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
- // CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
- memref.store %0, %out[%i, %j] : memref<3x5xf32>
+ // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: emitc.assign %[[LOAD]] : f32 to %[[SUBSCRIPT:.*]] : f32
+ memref.store %1, %0[%i, %j] : memref<4x8xf32>
return
}
>From 686324696cdc552b7f7757c9e5570c93f22fc92a Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Thu, 21 Mar 2024 13:12:07 +0100
Subject: [PATCH 8/8] Remove dependency on Func dialect libs
---
mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt | 2 --
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 2 --
2 files changed, 4 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
index 372c5b89d83008..8a72e747d024ae 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -13,8 +13,6 @@ add_mlir_conversion_library(MLIRMemRefToEmitC
LINK_LIBS PUBLIC
MLIREmitCDialect
- MLIRFuncDialect
- MLIRFuncTransforms
MLIRMemRefDialect
MLIRTransforms
)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2582fd78cd6c81..6e9add87e4c64c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8197,8 +8197,6 @@ cc_library(
deps = [
":ConversionPassIncGen",
":EmitCDialect",
- ":FuncDialect",
- ":FuncTransforms",
":MemRefDialect",
":IR",
":Pass",
More information about the Mlir-commits
mailing list