[Mlir-commits] [mlir] 991cb14 - [mlir][memref][transform] Add new alloca_to_global op. (#66511)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 21 09:17:06 PDT 2023


Author: Ingo Müller
Date: 2023-09-21T18:17:00+02:00
New Revision: 991cb147152ab22ad0bc9f642fc221eccd2b8e37

URL: https://github.com/llvm/llvm-project/commit/991cb147152ab22ad0bc9f642fc221eccd2b8e37
DIFF: https://github.com/llvm/llvm-project/commit/991cb147152ab22ad0bc9f642fc221eccd2b8e37.diff

LOG: [mlir][memref][transform] Add new alloca_to_global op. (#66511)

This PR adds a new transform op that replaces `memref.alloca`s with
`memref.get_global`s that access newly inserted `memref.global`s. This is
useful, for example, for allocations that should reside in the shared memory
of a GPU, which have to be declared as globals.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
    mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
    mlir/python/mlir/dialects/_memref_transform_ops_ext.py
    mlir/test/Dialect/MemRef/transform-ops.mlir
    mlir/test/python/dialects/transform_memref_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 681759f970cb910..d7bd8410e360a76 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -144,6 +144,69 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
 }
 
 def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">;
+
+// Note: `TransformOpInterface` is attached once, via
+// `DeclareOpInterfaceMethods`, matching the other transform ops in this file.
+def MemRefAllocaToGlobalOp :
+  Op<Transform_Dialect, "memref.alloca_to_global",
+     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Inserts a new `memref.global` for each provided `memref.alloca` into the
+    nearest symbol table (e.g., a `builtin.module`) and replaces the
+    `memref.alloca` with a `memref.get_global` that references the new global.
+    Each global is private, has the same type as the replaced alloca, and is
+    named `alloca` (uniqued by the symbol table if that name is already taken).
+    This is useful, for example, for allocations that should reside in the
+    shared memory of a GPU, which have to be declared as globals.
+
+    #### Example
+
+    Consider the following transform op:
+
+    ```mlir
+    %get_global, %global =
+        transform.memref.alloca_to_global %alloca
+          : (!transform.op<"memref.alloca">)
+            -> (!transform.any_op, !transform.any_op)
+    ```
+
+    and the following input payload:
+
+    ```mlir
+    module {
+      func.func @func() {
+        %alloca = memref.alloca() : memref<2x32xf32>
+        // usages of %alloca...
+        return
+      }
+    }
+    ```
+
+    then applying the transform op to the payload would result in the following
+    output IR:
+
+    ```mlir
+    module {
+      memref.global "private" @alloca : memref<2x32xf32>
+      func.func @func() {
+        %alloca = memref.get_global @alloca : memref<2x32xf32>
+        // usages of %alloca...
+        return
+      }
+    }
+    ```
+
+    #### Return modes
+
+    The payload `memref.alloca` ops must be nested in a symbol table; given
+    that, the op always succeeds. The `$alloca` handle is consumed. The
+    returned handles refer to the `memref.get_global` and `memref.global` ops
+    that were inserted by the transformation.
+  }];
+
+  let arguments = (ins Transform_MemRefAllocaOp:$alloca);
+  let results = (outs TransformHandleTypeInterface:$getGlobal,
+                  TransformHandleTypeInterface:$global);
+
+  let assemblyFormat = [{
+    $alloca attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,

diff  --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 58f4d8d8f6d21fe..eed29efcaaada88 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -126,6 +126,67 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Replaces every payload `memref.alloca` of the `$alloca` handle with a
+/// `memref.get_global` of a newly created private `memref.global` that is
+/// inserted into the alloca's nearest symbol table.
+///
+/// Results: `$global` is mapped to the created globals and `$getGlobal` to the
+/// ops that replaced the allocas (in payload order). Emits a definite failure
+/// if an alloca is not nested in any symbol table.
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector<memref::GlobalOp> globalOps;
+  SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+  // Cache symbol tables across iterations: building a `SymbolTable` scans the
+  // whole body of the table op, so rebuilding one per alloca is quadratic when
+  // many allocas share the same module.
+  SymbolTableCollection symbolTables;
+
+  // Detached builder (no insertion point): the symbol table below decides
+  // where the created global ends up.
+  OpBuilder builder(ctx);
+
+  // Transform `memref.alloca`s.
+  for (Operation *op : allocaOps) {
+    auto alloca = cast<memref::AllocaOp>(op);
+    Location loc = alloca->getLoc();
+
+    // Find nearest symbol table. Without one there is nowhere to declare the
+    // global, so fail loudly instead of dereferencing a null pointer.
+    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+    if (!symbolTableOp) {
+      DiagnosedDefiniteFailure diag =
+          emitDefiniteFailure()
+          << "expected payload 'memref.alloca' to be nested in a symbol table";
+      diag.attachNote(loc) << "payload op";
+      return diag;
+    }
+    SymbolTable &symbolTable = symbolTables.getSymbolTable(symbolTableOp);
+
+    // Insert a `memref.global` with the alloca's type into the symbol table;
+    // `insert` also uniques the symbol name if `alloca` is already taken.
+    // TODO: Add a better builder for this.
+    auto globalOp = builder.create<memref::GlobalOp>(
+        loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
+        TypeAttr::get(alloca.getResult().getType()), Attribute{}, UnitAttr{},
+        IntegerAttr{});
+    symbolTable.insert(globalOp);
+
+    // Replace the `memref.alloca` with a `memref.get_global` accessing the
+    // global symbol inserted above. The name is read after insertion so it
+    // reflects any renaming done by the symbol table.
+    rewriter.setInsertionPoint(alloca);
+    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+        alloca, globalOp.getType(), globalOp.getName());
+
+    globalOps.push_back(globalOp);
+    getGlobalOps.push_back(getGlobalOp);
+  }
+
+  // Assemble results.
+  results.set(cast<OpResult>(getGlobal()), globalOps);
+  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares the op's effects on transform handles and on the payload IR.
+void transform::MemRefAllocaToGlobalOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Both result handles are freshly produced (`$global` first, then
+  // `$getGlobal`).
+  for (Value resultHandle : {Value(getGlobal()), Value(getGetGlobal())})
+    producesHandle(resultHandle, effects);
+
+  // The input handle is consumed because `apply` replaces its payload allocas.
+  consumesHandle(getAlloca(), effects);
+
+  // The payload IR is rewritten.
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 4afe8e7b887f68e..1cc00bdcbf381c9 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -11,6 +11,52 @@
 from typing import Optional, overload, Union
 
 
+class MemRefAllocaToGlobalOp:
+    """Specialization for MemRefAllocaToGlobalOp class."""
+
+    @overload
+    def __init__(
+        self,
+        get_global_type: Type,
+        global_type: Type,
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
+        ...
+
+    def __init__(
+        self,
+        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
+        global_type_or_none: Optional[Type] = None,
+        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(get_global_type_or_alloca, Type):
+            get_global_type = get_global_type_or_alloca
+            global_type = global_type_or_none
+            alloca = alloca_or_none
+        else:
+            get_global_type = transform.AnyOpType.get()
+            global_type = transform.AnyOpType.get()
+            alloca = get_global_type_or_alloca
+
+        super().__init__(
+            get_global_type,
+            global_type,
+            alloca,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MemRefMultiBufferOp:
     """Specialization for MemRefMultiBufferOp class."""
 

diff  --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index b19db447af1c28a..68fea1f8402955c 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -1,5 +1,36 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
 
+// Test case: both `memref.alloca` ops inside the `scf.forall` below are
+// matched by the transform sequence and rewritten into a private
+// `memref.global` plus a `memref.get_global` with the same memref type; the
+// stores must then use the `memref.get_global` results.
+// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
+// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
+
+// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index)
+func.func @func(%lb: index, %ub: index) {
+  // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]])
+  scf.forall (%arg0, %arg1) in (%lb, %ub) {
+    // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
+    // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    %cst = arith.constant 0.0 : f32
+    %mr0 = memref.alloca() : memref<2x32xf32>
+    %mr1 = memref.alloca() : memref<2x32xf32>
+    memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32>
+    memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32>
+  }
+  return
+}
+
+// Match every `memref.alloca` in the payload and convert all of them with a
+// single `alloca_to_global` application.
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
+      : (!transform.any_op) -> !transform.op<"memref.alloca">
+  %get_global, %global = transform.memref.alloca_to_global %alloca
+        : (!transform.op<"memref.alloca">)
+          -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 

diff  --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f89005cb2f86d1b..e7d871c9eac8c7f 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -16,6 +16,41 @@ def run(f):
     return f
 
 
+ at run
+def testMemRefAllocaToAllocOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloca"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefAllocaToGlobalOp(sequence.bodyTarget)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: (!transform.op<"memref.alloca">)
+    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
+
+
+ at run
+def testMemRefAllocaToAllocOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloca"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefAllocaToGlobalOp(
+            transform.OperationType.get("memref.get_global"),
+            transform.OperationType.get("memref.global"),
+            sequence.bodyTarget,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
+
+
 @run
 def testMemRefMultiBufferOpCompact():
     sequence = transform.SequenceOp(


        


More information about the Mlir-commits mailing list