[Mlir-commits] [mlir] 0aa6d57 - [MLIR] Add initial convert-memref-to-emitc pass (#85389)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 21 06:27:41 PDT 2024
Author: Matthias Gehre
Date: 2024-03-21T14:27:37+01:00
New Revision: 0aa6d57e575dd920db81bef7ff509c4d3a9c6891
URL: https://github.com/llvm/llvm-project/commit/0aa6d57e575dd920db81bef7ff509c4d3a9c6891
DIFF: https://github.com/llvm/llvm-project/commit/0aa6d57e575dd920db81bef7ff509c4d3a9c6891.diff
LOG: [MLIR] Add initial convert-memref-to-emitc pass (#85389)
This converts `memref.alloca`, `memref.load` & `memref.store` to
`emitc.variable`, `emitc.subscript` and `emitc.assign`.
Added:
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
new file mode 100644
index 00000000000000..734ffdba520c9f
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -0,0 +1,21 @@
+//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+class TypeConverter;
+
+/// Registers on `typeConverter` a conversion from MemRefType to
+/// emitc::ArrayType. Only memrefs with a static shape, an identity layout and
+/// rank >= 1 are handled. The element type is converted through
+/// `typeConverter` itself; if that fails, the memref is not converted.
+void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
+
+/// Adds patterns that lower memref.alloca, memref.load and memref.store to
+/// emitc.variable, emitc.subscript and emitc.assign. `converter` supplies the
+/// converted types for these patterns.
+void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+ TypeConverter &converter);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
new file mode 100644
index 00000000000000..4a63014c19ad0e
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
@@ -0,0 +1,20 @@
+//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+// Pull in the TableGen-generated declarations for the ConvertMemRefToEmitC
+// pass defined in Passes.td.
+#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index f2aa4fb535402d..2179ae18ac074b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -45,6 +45,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..7e7ee3a2f780f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
];
}
+//===----------------------------------------------------------------------===//
+// MemRefToEmitC
+//===----------------------------------------------------------------------===//
+
+// Lowers memref.alloca, memref.load and memref.store to emitc.variable,
+// emitc.subscript and emitc.assign. The whole memref dialect is marked
+// illegal, so any other memref op makes the pass fail.
+def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
+ let summary = "Convert MemRef dialect to EmitC dialect";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// MemRefToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 8219cf98575f3c..41ab7046b91ce3 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
add_subdirectory(NVGPUToNVVM)
diff --git a/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..8a72e747d024ae
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
@@ -0,0 +1,18 @@
+# MemRef -> EmitC conversion: the rewrite patterns (MemRefToEmitC.cpp) and the
+# convert-memref-to-emitc pass driver (MemRefToEmitCPass.cpp).
+add_mlir_conversion_library(MLIRMemRefToEmitC
+ MemRefToEmitC.cpp
+ MemRefToEmitCPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIREmitCDialect
+ MLIRMemRefDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
new file mode 100644
index 00000000000000..0e3b6469212640
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -0,0 +1,114 @@
+//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers a statically shaped `memref.alloca` that carries no alignment
+/// requirement to an `emitc.variable` of the converted array type, created
+/// with an empty opaque initializer.
+struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto memrefType = op.getType();
+
+    // A C array needs compile-time extents.
+    if (!memrefType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          loc, "cannot transform alloca with dynamic shape");
+
+    // TODO: Allow alignment if it is not more than the natural alignment
+    // of the C array.
+    const auto alignment = op.getAlignment();
+    if (alignment && *alignment > 1)
+      return rewriter.notifyMatchFailure(
+          loc, "cannot transform alloca with alignment requirement");
+
+    Type arrayType = getTypeConverter()->convertType(memrefType);
+    if (!arrayType)
+      return rewriter.notifyMatchFailure(loc, "cannot convert type");
+
+    // Empty opaque attribute: the variable gets no initial value.
+    auto uninitialized = emitc::OpaqueAttr::get(getContext(), "");
+    rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, arrayType,
+                                                   uninitialized);
+    return success();
+  }
+};
+
+/// Lowers `memref.load` to an `emitc.subscript` of the converted array whose
+/// element is assigned into a fresh scalar `emitc.variable`; that variable
+/// replaces the load result.
+struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type scalarType = getTypeConverter()->convertType(op.getType());
+    if (!scalarType)
+      return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+
+    Location loc = op.getLoc();
+    // The subscript must be created before the variable to keep IR order.
+    auto element = rewriter.create<emitc::SubscriptOp>(
+        loc, adaptor.getMemref(), adaptor.getIndices());
+    auto uninitialized = emitc::OpaqueAttr::get(getContext(), "");
+    auto result =
+        rewriter.create<emitc::VariableOp>(loc, scalarType, uninitialized);
+    rewriter.create<emitc::AssignOp>(loc, result, element);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Lowers `memref.store` to an `emitc.assign` of the stored value into an
+/// `emitc.subscript` of the converted array.
+struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto element = rewriter.create<emitc::SubscriptOp>(
+        loc, adaptor.getMemref(), adaptor.getIndices());
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, element,
+                                                 adaptor.getValue());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers a conversion from MemRefType to emitc::ArrayType on
+/// `typeConverter`.
+///
+/// The conversion returns std::nullopt (the type is not handled here) unless
+/// all of the following hold:
+///   - the memref has a static shape;
+///   - it has an identity layout;
+///   - its rank is >= 1;
+///   - no dimension is zero-sized;
+///   - its element type converts through `typeConverter`.
+void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
+  // The callback is stored inside `typeConverter` and calls back into it to
+  // convert the element type, hence the by-reference capture.
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType memRefType) -> std::optional<Type> {
+        if (!memRefType.hasStaticShape() ||
+            !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0)
+          return std::nullopt;
+
+        // emitc::ArrayType only accepts strictly positive dimensions, so a
+        // zero extent (e.g. memref<0xf32>) must be rejected here rather than
+        // handed to ArrayType::get. getNumElements() requires a static shape,
+        // which was checked above.
+        if (memRefType.getNumElements() == 0)
+          return std::nullopt;
+
+        Type convertedElementType =
+            typeConverter.convertType(memRefType.getElementType());
+        if (!convertedElementType)
+          return std::nullopt;
+        return emitc::ArrayType::get(memRefType.getShape(),
+                                     convertedElementType);
+      });
+}
+
+/// Adds the alloca, load and store lowering patterns to `patterns`; each of
+/// them converts types through `converter`.
+void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
+                                                    TypeConverter &converter) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter, context);
+}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
new file mode 100644
index 00000000000000..4e5d1912d15729
--- /dev/null
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -0,0 +1,55 @@
+//===- MemRefToEmitCPass.cpp - A Pass to convert MemRef to EmitC ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert memref ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
+
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+// Emit the TableGen-generated pass base class
+// (impl::ConvertMemRefToEmitCBase) that ConvertMemRefToEmitCPass below
+// derives from.
+#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Applies the MemRefToEmitC type conversion and patterns as a partial
+/// conversion. The whole memref dialect is illegal, so the pass fails if any
+/// memref op cannot be rewritten.
+struct ConvertMemRefToEmitCPass
+    : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+
+    TypeConverter typeConverter;
+    // Keep every non-memref type unchanged; memref types are left to the
+    // conversion registered by populateMemRefToEmitCTypeConversion.
+    typeConverter.addConversion([](Type type) -> std::optional<Type> {
+      if (isa<MemRefType>(type))
+        return std::nullopt;
+      return type;
+    });
+    populateMemRefToEmitCTypeConversion(typeConverter);
+
+    ConversionTarget target(context);
+    target.addLegalDialect<emitc::EmitCDialect>();
+    target.addIllegalDialect<memref::MemRefDialect>();
+
+    RewritePatternSet patterns(&context);
+    populateMemRefToEmitCConversionPatterns(patterns, typeConverter);
+
+    LogicalResult result =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
new file mode 100644
index 00000000000000..390190d341e5ae
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
+
+// memref.copy has no lowering pattern and the memref dialect is illegal.
+func.func @memref_op(%arg0 : memref<2x4xf32>) {
+  // expected-error@+1 {{failed to legalize operation 'memref.copy'}}
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// -----
+
+// Dynamic extents cannot be lowered to an emitc.array.
+func.func @alloca_with_dynamic_shape() {
+  %0 = index.constant 1
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %1 = memref.alloca(%0) : memref<4x?xf32>
+  return
+}
+
+// -----
+
+// Alignment requirements greater than 1 are not supported yet.
+func.func @alloca_with_alignment() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
+  return
+}
+
+// -----
+
+// Only identity layouts are converted to emitc.array.
+func.func @non_identity_layout() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
+  return
+}
+
+// -----
+
+// Rank-0 memrefs are not converted to emitc.array.
+func.func @zero_rank() {
+  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
+  %0 = memref.alloca() : memref<f32>
+  return
+}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
new file mode 100644
index 00000000000000..9793b2d6d7832f
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref_store
+// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_store(%v : f32, %i: index, %j: index) {
+  // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+
+  // The store must become an assignment whose target is the subscript
+  // result captured on the previous line, so reuse that variable here.
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT]] : f32
+  memref.store %v, %0[%i, %j] : memref<4x8xf32>
+  return
+}
+// -----
+
+// CHECK-LABEL: memref_load
+// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
+func.func @memref_load(%i: index, %j: index) -> f32 {
+  // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  %0 = memref.alloca() : memref<4x8xf32>
+
+  // A load becomes a fresh scalar variable assigned from the subscripted
+  // array element; that variable replaces the loaded value.
+  // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+  // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  // CHECK: emitc.assign %[[SUBSCRIPT]] : f32 to %[[VAR]] : f32
+  %1 = memref.load %0[%i, %j] : memref<4x8xf32>
+  // CHECK: return %[[VAR]] : f32
+  return %1 : f32
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 201c7f653398ae..3b575d4a413c31 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4186,6 +4186,7 @@ cc_library(
":MathToLLVM",
":MathToLibm",
":MathToSPIRV",
+ ":MemRefToEmitC",
":MemRefToLLVM",
":MemRefToSPIRV",
":NVGPUToNVVM",
@@ -8256,6 +8257,32 @@ cc_library(
],
)
+# MemRef -> EmitC conversion patterns and the convert-memref-to-emitc pass.
+cc_library(
+    name = "MemRefToEmitC",
+    srcs = glob([
+        "lib/Conversion/MemRefToEmitC/*.cpp",
+        "lib/Conversion/MemRefToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/MemRefToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/MemRefToEmitC",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":IR",
+        ":MemRefDialect",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
cc_library(
name = "MemRefToLLVM",
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),
More information about the Mlir-commits
mailing list