[Mlir-commits] [mlir] 525d60b - [mlir][mem2reg] Add support for mem2reg in MemRef.
Tobias Gysi
llvmlistbot at llvm.org
Thu May 4 05:55:00 PDT 2023
Author: Théo Degioanni
Date: 2023-05-04T12:44:15Z
New Revision: 525d60bf3501629c4c0b1812086053efafcad37b
URL: https://github.com/llvm/llvm-project/commit/525d60bf3501629c4c0b1812086053efafcad37b
DIFF: https://github.com/llvm/llvm-project/commit/525d60bf3501629c4c0b1812086053efafcad37b.diff
LOG: [mlir][mem2reg] Add support for mem2reg in MemRef.
This patch implements the mem2reg interfaces for MemRef operations. It only supports scalar (single-element) memrefs whose element type is one of a small set of supported types. It would be beneficial to create more interfaces for default values before expanding support to more types. Additionally, I am working on an upcoming revision to bring SROA to MLIR that should help with non-scalar memrefs.
Reviewed By: gysit, Mogball
Differential Revision: https://reviews.llvm.org/D149441
Added:
mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
mlir/test/Dialect/LLVMIR/mem2reg.mlir
mlir/test/Dialect/MemRef/mem2reg.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
Removed:
mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir
mlir/test/Transforms/mem2reg-llvmir.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b599c3c56ca78..41f130e074de1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -16,6 +16,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/CopyOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/Mem2RegInterfaces.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e19e919e3e743..1ea0ef0057b7d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/Mem2RegInterfaces.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -309,7 +310,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
//===----------------------------------------------------------------------===//
def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
let summary = "stack memory allocation operation";
let description = [{
The `alloca` operation allocates memory on the stack, to be automatically
@@ -1159,7 +1161,8 @@ def LoadOp : MemRef_Op<"load",
[TypesMatchWith<"result type matches element type of 'memref'",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">,
- MemRefsNormalizable]> {
+ MemRefsNormalizable,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
let summary = "load operation";
let description = [{
The `load` op reads an element from a memref specified by an index list. The
@@ -1748,7 +1751,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
[TypesMatchWith<"type of 'value' matches element type of 'memref'",
"memref", "value",
"$_self.cast<MemRefType>().getElementType()">,
- MemRefsNormalizable]> {
+ MemRefsNormalizable,
+ DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
let summary = "store operation";
let description = [{
Store a value to a memref location given by indices. The value stored should
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 3aedd3783fa8f..1b01f783a2224 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefDialect
MemRefDialect.cpp
+ MemRefMem2Reg.cpp
MemRefOps.cpp
ValueBoundsOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
new file mode 100644
index 0000000000000..acc38b57c7fdb
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
@@ -0,0 +1,120 @@
+//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Mem2Reg-related interfaces for the MemRef dialect
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// AllocaOp interfaces
+//===----------------------------------------------------------------------===//
+
+/// Returns true if a memory slot holding a value of `type` can be promoted,
+/// i.e. if `memref::AllocaOp::getDefaultValue` can materialize a valid default
+/// value of that type.
+static bool isSupportedElementType(Type type) {
+  // Memref-typed slots default to a fresh `memref.alloca` that is created
+  // without any dynamic-size or symbol operands. Only accept memref types that
+  // need none of those; otherwise the materialized default would be an
+  // invalid operation.
+  if (auto memrefType = type.dyn_cast<MemRefType>())
+    return memrefType.hasStaticShape() &&
+           memrefType.getLayout().getAffineMap().getNumSymbols() == 0;
+  // Any other type is supported when a zero attribute can be built for it,
+  // since the default is then an `arith.constant` of that zero.
+  return static_cast<bool>(OpBuilder(type.getContext()).getZeroAttr(type));
+}
+
+/// Returns the single promotable slot of this alloca, or nothing if the
+/// allocation cannot be modeled as one scalar slot.
+SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
+  MemRefType type = getType();
+  if (!isSupportedElementType(type.getElementType()))
+    return {};
+  if (!type.hasStaticShape())
+    return {};
+  // Only memrefs holding exactly one element (rank 0, or every dimension equal
+  // to 1) are promoted. `getNumElements` requires a static shape, which was
+  // checked above, and avoids comparing the signed dimensions through an
+  // unsigned lambda parameter.
+  if (type.getNumElements() != 1)
+    return {};
+
+  return {MemorySlot{getResult(), type.getElementType()}};
+}
+
+/// Materializes the value a load observes when it reads the slot before any
+/// store: a new `memref.alloca` for memref-typed slots, and an
+/// `arith.constant` zero for every other supported type.
+Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
+                                        OpBuilder &builder) {
+  Type elemType = slot.elemType;
+  assert(isSupportedElementType(elemType));
+  // TODO: support more types.
+  // NOTE(review): the alloca is built without dynamic-size or symbol operands,
+  // so this relies on `isSupportedElementType` only admitting memref types
+  // that need none — confirm.
+  if (auto memrefType = elemType.dyn_cast<MemRefType>())
+    return builder.create<memref::AllocaOp>(getLoc(), memrefType);
+  return builder.create<arith::ConstantOp>(getLoc(), elemType,
+                                           builder.getZeroAttr(elemType));
+}
+
+/// Cleans up once this alloca's slot has been promoted. The op that defines
+/// the default value is erased if nothing ended up using it, and then the
+/// alloca itself is erased.
+void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                               Value defaultValue) {
+  // Guard against a null default value. Calling `use_empty()` or
+  // `getDefiningOp()` on a null Value would dereference a null implementation.
+  if (defaultValue && defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  erase();
+}
+
+/// Hook invoked for block arguments introduced while promoting this slot.
+/// MemRef allocas need no extra bookkeeping for them, so this does nothing.
+void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
+                                           BlockArgument argument,
+                                           OpBuilder &builder) {
+  // Intentionally empty.
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp/StoreOp interfaces
+//===----------------------------------------------------------------------===//
+
+/// A load reads from `slot` exactly when its memref operand is the slot
+/// pointer.
+bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
+  Value source = getMemRef();
+  return source == slot.ptr;
+}
+
+/// Loads never write memory, so there is no stored value to report.
+Value memref::LoadOp::getStored(const MemorySlot &slot) {
+  return Value();
+}
+
+/// Accepts only the case where the single blocking use is this load's memref
+/// operand referring to the slot, and the loaded type matches the slot's
+/// element type.
+bool memref::LoadOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  OpOperand *onlyUse = *blockingUses.begin();
+  if (onlyUse->get() != slot.ptr || getMemRef() != slot.ptr)
+    return false;
+  // The reaching definition replaces the loaded value directly, so the types
+  // must agree.
+  return getResult().getType() == slot.elemType;
+}
+
+/// Forwards the value reaching this load to all users of the loaded result,
+/// then requests deletion of the now-dead load.
+DeletionKind memref::LoadOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition) {
+  // `canUsesBeRemoved` already verified that the only blocking use is the
+  // loaded slot pointer.
+  Value loaded = getResult();
+  loaded.replaceAllUsesWith(reachingDefinition);
+  return DeletionKind::Delete;
+}
+
+/// Stores only write memory; they never read from the slot.
+bool memref::StoreOp::loadsFrom(const MemorySlot &slot) {
+  return false;
+}
+
+/// Returns the value written by this store when it targets `slot`, and a null
+/// Value otherwise.
+Value memref::StoreOp::getStored(const MemorySlot &slot) {
+  bool writesSlot = getMemRef() == slot.ptr;
+  return writesSlot ? getValue() : Value();
+}
+
+/// Accepts only the case where the single blocking use is this store's memref
+/// operand referring to the slot. The stored value must not be the slot
+/// pointer itself and must have the slot's element type.
+bool memref::StoreOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  OpOperand *onlyUse = *blockingUses.begin();
+  if (onlyUse->get() != slot.ptr || getMemRef() != slot.ptr)
+    return false;
+  Value stored = getValue();
+  return stored != slot.ptr && stored.getType() == slot.elemType;
+}
+
+/// A store to the promoted slot has no uses to rewrite; it is simply marked
+/// for deletion.
+DeletionKind memref::StoreOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition) {
+  DeletionKind kind = DeletionKind::Delete;
+  return kind;
+}
diff --git a/mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
similarity index 100%
rename from mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir
rename to mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
diff --git a/mlir/test/Transforms/mem2reg-llvmir.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
similarity index 100%
rename from mlir/test/Transforms/mem2reg-llvmir.mlir
rename to mlir/test/Dialect/LLVMIR/mem2reg.mlir
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
new file mode 100644
index 0000000000000..86707ac0b4971
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// A store followed by a load of a rank-0 i32 alloca: the load is replaced by
+// the stored constant and no alloca remains.
+// CHECK-LABEL: func.func @basic
+func.func @basic() -> i32 {
+ // CHECK-NOT: = memref.alloca
+ // CHECK: %[[RES:.*]] = arith.constant 5 : i32
+ // CHECK-NOT: = memref.alloca
+ %0 = arith.constant 5 : i32
+ %1 = memref.alloca() : memref<i32>
+ memref.store %0, %1[] : memref<i32>
+ %2 = memref.load %1[] : memref<i32>
+ // CHECK: return %[[RES]] : i32
+ return %2 : i32
+}
+
+// -----
+
+// The load happens before any store, so it observes the slot's default value,
+// which for i32 is expected to be `arith.constant 0`. The later store is
+// expected to be removed.
+// CHECK-LABEL: func.func @basic_default
+func.func @basic_default() -> i32 {
+ // CHECK-NOT: = memref.alloca
+ // CHECK: %[[RES:.*]] = arith.constant 0 : i32
+ // CHECK-NOT: = memref.alloca
+ %0 = arith.constant 5 : i32
+ %1 = memref.alloca() : memref<i32>
+ %2 = memref.load %1[] : memref<i32>
+ // CHECK-NOT: memref.store
+ memref.store %0, %1[] : memref<i32>
+ // CHECK: return %[[RES]] : i32
+ return %2 : i32
+}
+
+// -----
+
+// Same store-then-load pattern as @basic, but with an f32 slot.
+// CHECK-LABEL: func.func @basic_float
+func.func @basic_float() -> f32 {
+ // CHECK-NOT: = memref.alloca
+ // CHECK: %[[RES:.*]] = arith.constant {{.*}} : f32
+ %0 = arith.constant 5.2 : f32
+ // CHECK-NOT: = memref.alloca
+ %1 = memref.alloca() : memref<f32>
+ memref.store %0, %1[] : memref<f32>
+ %2 = memref.load %1[] : memref<f32>
+ // CHECK: return %[[RES]] : f32
+ return %2 : f32
+}
+
+// -----
+
+// A memref<1x1xi32> holds a single element, so it is promoted like a rank-0
+// memref even though its load and store use explicit indices.
+// CHECK-LABEL: func.func @basic_ranked
+func.func @basic_ranked() -> i32 {
+ // CHECK-NOT: = memref.alloca
+ // CHECK: %[[RES:.*]] = arith.constant 5 : i32
+ // CHECK-NOT: = memref.alloca
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 5 : i32
+ %2 = memref.alloca() : memref<1x1xi32>
+ memref.store %1, %2[%0, %0] : memref<1x1xi32>
+ %3 = memref.load %2[%0, %0] : memref<1x1xi32>
+ // CHECK: return %[[RES]] : i32
+ return %3 : i32
+}
+
+// -----
+
+// A memref<1x2xi32> holds two elements, so it is not promoted. The alloca,
+// the store and the load are all expected to remain unchanged.
+// CHECK-LABEL: func.func @reject_multiple_elements
+func.func @reject_multiple_elements() -> i32 {
+ // CHECK: %[[INDEX:.*]] = arith.constant 0 : index
+ %0 = arith.constant 0 : index
+ // CHECK: %[[STORED:.*]] = arith.constant 5 : i32
+ %1 = arith.constant 5 : i32
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca()
+ %2 = memref.alloca() : memref<1x2xi32>
+ // CHECK: memref.store %[[STORED]], %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
+ memref.store %1, %2[%0, %0] : memref<1x2xi32>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
+ %3 = memref.load %2[%0, %0] : memref<1x2xi32>
+ // CHECK: return %[[RES]] : i32
+ return %3 : i32
+}
+
+// -----
+
+// Promotion across a CFG cycle: the slot's current value is expected to be
+// passed as a block argument on both successors of the `cf.cond_br` and
+// around the ^bb1 <-> ^bb2 loop.
+// CHECK-LABEL: func.func @cycle
+// CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i64)
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+ // CHECK-NOT: = memref.alloca
+ %alloca = memref.alloca() : memref<i64>
+ memref.store %arg2, %alloca[] : memref<i64>
+ // CHECK: cf.cond_br %[[ARG1:.*]], ^[[BB1:.*]](%[[ARG2]] : i64), ^[[BB2:.*]](%[[ARG2]] : i64)
+ cf.cond_br %arg1, ^bb1, ^bb2
+// CHECK: ^[[BB1]](%[[USE:.*]]: i64):
+^bb1:
+ %use = memref.load %alloca[] : memref<i64>
+ // CHECK: call @use(%[[USE]])
+ func.call @use(%use) : (i64) -> ()
+ memref.store %arg0, %alloca[] : memref<i64>
+ // CHECK: cf.br ^[[BB2]](%[[ARG0]] : i64)
+ cf.br ^bb2
+// CHECK: ^[[BB2]](%[[FWD:.*]]: i64):
+^bb2:
+ // CHECK: cf.br ^[[BB1]](%[[FWD]] : i64)
+ cf.br ^bb1
+}
+
+// Trivial callee that consumes the loaded value in @cycle.
+func.func @use(%arg: i64) { return }
+
+// -----
+
+// Three nested single-element allocas, each storing the pointer of the next
+// one. All of them are expected to be promoted, and the function returns its
+// argument directly.
+// CHECK-LABEL: func.func @recursive
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+func.func @recursive(%arg: i64) -> i64 {
+ // CHECK-NOT: = memref.alloca()
+ %alloca0 = memref.alloca() : memref<memref<memref<i64>>>
+ %alloca1 = memref.alloca() : memref<memref<i64>>
+ %alloca2 = memref.alloca() : memref<i64>
+ memref.store %arg, %alloca2[] : memref<i64>
+ memref.store %alloca2, %alloca1[] : memref<memref<i64>>
+ memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
+ %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
+ %load1 = memref.load %load0[] : memref<memref<i64>>
+ %load2 = memref.load %load1[] : memref<i64>
+ // CHECK: return %[[ARG]] : i64
+ return %load2 : i64
+}
+
+// -----
+
+// The alloca pointer is itself stored into another memref (the function
+// argument), so it must not be promoted. The alloca and its store and load
+// are expected to remain.
+// CHECK-LABEL: func.func @deny_store_of_alloca
+// CHECK-SAME: (%[[ARG:.*]]: memref<memref<i32>>)
+func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
+ // CHECK: %[[VALUE:.*]] = arith.constant 5 : i32
+ %0 = arith.constant 5 : i32
+ // CHECK: %[[ALLOCA:.*]] = memref.alloca
+ %1 = memref.alloca() : memref<i32>
+ // Storing into the memref is allowed.
+ // CHECK: memref.store %[[VALUE]], %[[ALLOCA]][]
+ memref.store %0, %1[] : memref<i32>
+ // Storing the memref itself is NOT allowed.
+ // CHECK: memref.store %[[ALLOCA]], %[[ARG]][]
+ memref.store %1, %arg[] : memref<memref<i32>>
+ // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
+ %2 = memref.load %1[] : memref<i32>
+ // CHECK: return %[[RES]] : i32
+ return %2 : i32
+}
+
+// -----
+
+// The inner memref<i32> alloca is stored into the outer alloca and passed to
+// @use, so it stays. The outer memref<memref<i32>> alloca only sees a store
+// and a load, so it is expected to be promoted away.
+// CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
+func.func @promotable_nonpromotable_intertwined() -> i32 {
+ // CHECK: %[[VAL:.*]] = arith.constant 5 : i32
+ %0 = arith.constant 5 : i32
+ // CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
+ %1 = memref.alloca() : memref<i32>
+ // CHECK-NOT: = memref.alloca() : memref<memref<i32>>
+ %2 = memref.alloca() : memref<memref<i32>>
+ memref.store %1, %2[] : memref<memref<i32>>
+ %3 = memref.load %2[] : memref<memref<i32>>
+ // CHECK: call @use(%[[NON_PROMOTED]])
+ call @use(%1) : (memref<i32>) -> ()
+ // CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
+ %4 = memref.load %1[] : memref<i32>
+ // CHECK: return %[[RES]] : i32
+ return %4 : i32
+}
+
+// Trivial callee that takes the non-promoted alloca in the test above.
+func.func @use(%arg: memref<i32>) { return }
More information about the Mlir-commits
mailing list