[Mlir-commits] [mlir] 525d60b - [mlir][mem2reg] Add support for mem2reg in MemRef.

Tobias Gysi llvmlistbot at llvm.org
Thu May 4 05:55:00 PDT 2023


Author: Théo Degioanni
Date: 2023-05-04T12:44:15Z
New Revision: 525d60bf3501629c4c0b1812086053efafcad37b

URL: https://github.com/llvm/llvm-project/commit/525d60bf3501629c4c0b1812086053efafcad37b
DIFF: https://github.com/llvm/llvm-project/commit/525d60bf3501629c4c0b1812086053efafcad37b.diff

LOG: [mlir][mem2reg] Add support for mem2reg in MemRef.

This patch implements the mem2reg interfaces for the MemRef dialect. For now, only scalar (single-element) memrefs with a small set of element types are supported. It would be beneficial to create more interfaces for default values before expanding support to more types. Additionally, I am working on an upcoming revision that brings SROA to MLIR, which should help with non-scalar memrefs.
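
As a quick illustration (a minimal sketch mirroring the @basic test added below; the function and value names are illustrative), the pass now promotes a single-element memref.alloca by forwarding the stored value directly to its loads:

  // Before `mlir-opt --pass-pipeline='builtin.module(func.func(mem2reg))'`:
  func.func @example() -> i32 {
    %cst = arith.constant 5 : i32
    %mem = memref.alloca() : memref<i32>
    memref.store %cst, %mem[] : memref<i32>
    %val = memref.load %mem[] : memref<i32>
    return %val : i32
  }

  // After: the alloca, store, and load are removed; the constant flows
  // directly to the return.
  func.func @example() -> i32 {
    %cst = arith.constant 5 : i32
    return %cst : i32
  }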

Reviewed By: gysit, Mogball

Differential Revision: https://reviews.llvm.org/D149441

Added: 
    mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
    mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
    mlir/test/Dialect/LLVMIR/mem2reg.mlir
    mlir/test/Dialect/MemRef/mem2reg.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Removed: 
    mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir
    mlir/test/Transforms/mem2reg-llvmir.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b599c3c56ca78..41f130e074de1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/Mem2RegInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e19e919e3e743..1ea0ef0057b7d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/Mem2RegInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -309,7 +310,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
 //===----------------------------------------------------------------------===//
 
 def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
   let summary = "stack memory allocation operation";
   let description = [{
     The `alloca` operation allocates memory on the stack, to be automatically
@@ -1159,7 +1161,8 @@ def LoadOp : MemRef_Op<"load",
      [TypesMatchWith<"result type matches element type of 'memref'",
                      "memref", "result",
                      "$_self.cast<MemRefType>().getElementType()">,
-                     MemRefsNormalizable]> {
+      MemRefsNormalizable,
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref specified by an index list. The
@@ -1748,7 +1751,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
      [TypesMatchWith<"type of 'value' matches element type of 'memref'",
                      "memref", "value",
                      "$_self.cast<MemRefType>().getElementType()">,
-                     MemRefsNormalizable]> {
+      MemRefsNormalizable,
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
   let summary = "store operation";
   let description = [{
     Store a value to a memref location given by indices. The value stored should

diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 3aedd3783fa8f..1b01f783a2224 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
+  MemRefMem2Reg.cpp
   MemRefOps.cpp
   ValueBoundsOpInterfaceImpl.cpp
 

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
new file mode 100644
index 0000000000000..acc38b57c7fdb
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
@@ -0,0 +1,120 @@
+//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Mem2Reg-related interfaces for MemRef dialect
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/Mem2RegInterfaces.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+//  AllocaOp interfaces
+//===----------------------------------------------------------------------===//
+
+static bool isSupportedElementType(Type type) {
+  return type.isa<MemRefType>() ||
+         OpBuilder(type.getContext()).getZeroAttr(type);
+}
+
+SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
+  MemRefType type = getType();
+  if (!isSupportedElementType(type.getElementType()))
+    return {};
+  if (!type.hasStaticShape())
+    return {};
+  // Make sure the memref contains only a single element.
+  if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; }))
+    return {};
+
+  return {MemorySlot{getResult(), type.getElementType()}};
+}
+
+Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
+                                        OpBuilder &builder) {
+  assert(isSupportedElementType(slot.elemType));
+  // TODO: support more types.
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](MemRefType t) {
+        return builder.create<memref::AllocaOp>(getLoc(), t);
+      })
+      .Default([&](Type t) {
+        return builder.create<arith::ConstantOp>(getLoc(), t,
+                                                 builder.getZeroAttr(t));
+      });
+}
+
+void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                               Value defaultValue) {
+  if (defaultValue.use_empty())
+    defaultValue.getDefiningOp()->erase();
+  erase();
+}
+
+void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
+                                           BlockArgument argument,
+                                           OpBuilder &builder) {}
+
+//===----------------------------------------------------------------------===//
+//  LoadOp/StoreOp interfaces
+//===----------------------------------------------------------------------===//
+
+bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr;
+}
+
+Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+
+bool memref::LoadOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getResult().getType() == slot.elemType;
+}
+
+DeletionKind memref::LoadOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition) {
+  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
+  // pointer.
+  getResult().replaceAllUsesWith(reachingDefinition);
+  return DeletionKind::Delete;
+}
+
+bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+Value memref::StoreOp::getStored(const MemorySlot &slot) {
+  if (getMemRef() != slot.ptr)
+    return {};
+  return getValue();
+}
+
+bool memref::StoreOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getValue() != slot.ptr && getValue().getType() == slot.elemType;
+}
+
+DeletionKind memref::StoreOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    OpBuilder &builder, Value reachingDefinition) {
+  return DeletionKind::Delete;
+}

diff --git a/mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir b/mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
similarity index 100%
rename from mlir/test/Transforms/mem2reg-llvmir-dbginfo.mlir
rename to mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir

diff --git a/mlir/test/Transforms/mem2reg-llvmir.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
similarity index 100%
rename from mlir/test/Transforms/mem2reg-llvmir.mlir
rename to mlir/test/Dialect/LLVMIR/mem2reg.mlir

diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
new file mode 100644
index 0000000000000..86707ac0b4971
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic
+func.func @basic() -> i32 {
+  // CHECK-NOT: = memref.alloca
+  // CHECK: %[[RES:.*]] = arith.constant 5 : i32
+  // CHECK-NOT: = memref.alloca
+  %0 = arith.constant 5 : i32
+  %1 = memref.alloca() : memref<i32>
+  memref.store %0, %1[] : memref<i32>
+  %2 = memref.load %1[] : memref<i32>
+  // CHECK: return %[[RES]] : i32
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_default
+func.func @basic_default() -> i32 {
+  // CHECK-NOT: = memref.alloca
+  // CHECK: %[[RES:.*]] = arith.constant 0 : i32
+  // CHECK-NOT: = memref.alloca
+  %0 = arith.constant 5 : i32
+  %1 = memref.alloca() : memref<i32>
+  %2 = memref.load %1[] : memref<i32>
+  // CHECK-NOT: memref.store
+  memref.store %0, %1[] : memref<i32>
+  // CHECK: return %[[RES]] : i32
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_float
+func.func @basic_float() -> f32 {
+  // CHECK-NOT: = memref.alloca
+  // CHECK: %[[RES:.*]] = arith.constant {{.*}} : f32
+  %0 = arith.constant 5.2 : f32
+  // CHECK-NOT: = memref.alloca
+  %1 = memref.alloca() : memref<f32>
+  memref.store %0, %1[] : memref<f32>
+  %2 = memref.load %1[] : memref<f32>
+  // CHECK: return %[[RES]] : f32
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_ranked
+func.func @basic_ranked() -> i32 {
+  // CHECK-NOT: = memref.alloca
+  // CHECK: %[[RES:.*]] = arith.constant 5 : i32
+  // CHECK-NOT: = memref.alloca
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 5 : i32
+  %2 = memref.alloca() : memref<1x1xi32>
+  memref.store %1, %2[%0, %0] : memref<1x1xi32>
+  %3 = memref.load %2[%0, %0] : memref<1x1xi32>
+  // CHECK: return %[[RES]] : i32
+  return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @reject_multiple_elements
+func.func @reject_multiple_elements() -> i32 {
+  // CHECK: %[[INDEX:.*]] = arith.constant 0 : index
+  %0 = arith.constant 0 : index
+  // CHECK: %[[STORED:.*]] = arith.constant 5 : i32
+  %1 = arith.constant 5 : i32
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca()
+  %2 = memref.alloca() : memref<1x2xi32>
+  // CHECK: memref.store %[[STORED]], %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
+  memref.store %1, %2[%0, %0] : memref<1x2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
+  %3 = memref.load %2[%0, %0] : memref<1x2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cycle
+// CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i64)
+func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
+  // CHECK-NOT: = memref.alloca
+  %alloca = memref.alloca() : memref<i64>
+  memref.store %arg2, %alloca[] : memref<i64>
+  // CHECK: cf.cond_br %[[ARG1:.*]], ^[[BB1:.*]](%[[ARG2]] : i64), ^[[BB2:.*]](%[[ARG2]] : i64)
+  cf.cond_br %arg1, ^bb1, ^bb2
+// CHECK: ^[[BB1]](%[[USE:.*]]: i64):
+^bb1:
+  %use = memref.load %alloca[] : memref<i64>
+  // CHECK: call @use(%[[USE]])
+  func.call @use(%use) : (i64) -> ()
+  memref.store %arg0, %alloca[] : memref<i64>
+  // CHECK: cf.br ^[[BB2]](%[[ARG0]] : i64)
+  cf.br ^bb2
+// CHECK: ^[[BB2]](%[[FWD:.*]]: i64):
+^bb2:
+  // CHECK: cf.br ^[[BB1]](%[[FWD]] : i64)
+  cf.br ^bb1
+}
+
+func.func @use(%arg: i64) { return }
+
+// -----
+
+// CHECK-LABEL: func.func @recursive
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+func.func @recursive(%arg: i64) -> i64 {
+  // CHECK-NOT: = memref.alloca()
+  %alloca0 = memref.alloca() : memref<memref<memref<i64>>>
+  %alloca1 = memref.alloca() : memref<memref<i64>>
+  %alloca2 = memref.alloca() : memref<i64>
+  memref.store %arg, %alloca2[] : memref<i64>
+  memref.store %alloca2, %alloca1[] : memref<memref<i64>>
+  memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
+  %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
+  %load1 = memref.load %load0[] : memref<memref<i64>>
+  %load2 = memref.load %load1[] : memref<i64>
+  // CHECK: return %[[ARG]] : i64
+  return %load2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: func.func @deny_store_of_alloca
+// CHECK-SAME: (%[[ARG:.*]]: memref<memref<i32>>)
+func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
+  // CHECK: %[[VALUE:.*]] = arith.constant 5 : i32
+  %0 = arith.constant 5 : i32
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca
+  %1 = memref.alloca() : memref<i32>
+  // Storing into the memref is allowed.
+  // CHECK: memref.store %[[VALUE]], %[[ALLOCA]][]
+  memref.store %0, %1[] : memref<i32>
+  // Storing the memref itself is NOT allowed.
+  // CHECK: memref.store %[[ALLOCA]], %[[ARG]][]
+  memref.store %1, %arg[] : memref<memref<i32>>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
+  %2 = memref.load %1[] : memref<i32>
+  // CHECK: return %[[RES]] : i32
+  return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
+func.func @promotable_nonpromotable_intertwined() -> i32 {
+  // CHECK: %[[VAL:.*]] = arith.constant 5 : i32
+  %0 = arith.constant 5 : i32
+  // CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
+  %1 = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca() : memref<memref<i32>>
+  %2 = memref.alloca() : memref<memref<i32>>
+  memref.store %1, %2[] : memref<memref<i32>>
+  %3 = memref.load %2[] : memref<memref<i32>>
+  // CHECK: call @use(%[[NON_PROMOTED]])
+  call @use(%1) : (memref<i32>) -> ()
+  // CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
+  %4 = memref.load %1[] : memref<i32>
+  // CHECK: return %[[RES]] : i32
+  return %4 : i32
+}
+
+func.func @use(%arg: memref<i32>) { return }

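To try the new support on a standalone file (assuming it is named input.mlir), the invocation matches the test's RUN line:

  mlir-opt input.mlir --pass-pipeline='builtin.module(func.func(mem2reg))'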