[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 06:29:19 PDT 2023


Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/66511/mlir at github.com>


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core
            
<details>
<summary>Changes</summary>
This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.
--
Full diff: https://github.com/llvm/llvm-project/pull/66511.diff

7 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td (+65) 
- (modified) mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp (+90) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+3-2) 
- (modified) mlir/python/mlir/dialects/_memref_transform_ops_ext.py (+58) 
- (modified) mlir/test/Dialect/MemRef/transform-ops.mlir (+39) 
- (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+12) 
- (modified) mlir/test/python/dialects/transform_memref_ext.py (+48) 


<pre>
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 681759f970cb910..6a78784d74dd53c 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op&lt;Transform_Dialect,
 }
 
 def Transform_MemRefAllocOp : Transform_ConcreteOpType&lt;&quot;memref.alloc&quot;&gt;;
+def Transform_MemRefAllocaOp : Transform_ConcreteOpType&lt;&quot;memref.alloca&quot;&gt;;
+
+def MemRefAllocaToGlobalOp :
+  Op&lt;Transform_Dialect, &quot;memref.alloca_to_global&quot;,
+     [TransformOpInterface,
+      DeclareOpInterfaceMethods&lt;MemoryEffectsOpInterface&gt;,
+      DeclareOpInterfaceMethods&lt;TransformOpInterface&gt;]&gt; {
+  let description = [{
+    Inserts a new `memref.global` for each provided `memref.alloca` into the
+    provided module and replaces it with a `memref.get_global`. This is useful,
+    for example, for allocations that should reside in the shared memory of
+    a GPU, which have to be declared as globals.
+
+    #### Example
+
+    Consider the following transform op:
+
+    ```mlir
+    %get_global, %global =
+        transform.memref.alloca_to_global %alloca in %module
+          : (!transform.op&lt;&quot;builtin.module&quot;&gt;, !transform.op&lt;&quot;memref.alloca&quot;&gt;)
+            -&gt; (!transform.any_op, !transform.any_op)
+    ```
+
+    and the following input payload:
+
+    ```mlir
+    module {
+      func.func @func() {
+        %alloca = memref.alloca() : memref&lt;2x32xf32&gt;
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    then applying the transform op to the payload would result in the following
+    output IR:
+
+    ```mlir
+    module {
+      memref.global &quot;private&quot; @alloca : memref&lt;2x32xf32&gt;
+      func.func @func() {
+        %alloca = memref.get_global @alloca : memref&lt;2x32xf32&gt;
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    #### Return modes
+
+    Emits a definite failure if not exactly one `module` payload op was provided
+    or any of the `alloca` payload ops is not inside that module, and succeeds
+    otherwise. The returned handles refer to the `memref.get_global` and
+    `memref.global` ops that were inserted by the transformation.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType&lt;&quot;builtin.module&quot;&gt;:$module,
+                   Transform_MemRefAllocaOp:$alloca);
+  let results = (outs TransformHandleTypeInterface:$get_global,
+                  TransformHandleTypeInterface:$global);
+
+  let assemblyFormat = [{
+    $alloca `in` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 def MemRefMultiBufferOp : Op&lt;Transform_Dialect, &quot;memref.multibuffer&quot;,
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 58f4d8d8f6d21fe..7467359da83c37f 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+static llvm::SmallString&lt;64&gt; getUniqueSymbol(llvm::StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString&lt;64&gt; candidateNameStorage;
+  StringRef candidateName(prefix);
+  int uniqueNumber = 0;
+  while (true) {
+    if (!module.lookupSymbol(candidateName)) {
+      break;
+    }
+    candidateNameStorage.clear();
+    candidateName = (prefix + Twine(&quot;_&quot;) + Twine(uniqueNumber))
+                        .toStringRef(candidateNameStorage);
+    uniqueNumber++;
+  }
+  return candidateName;
+}
+} // namespace
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &amp;rewriter,
+                                         transform::TransformResults &amp;results,
+                                         transform::TransformState &amp;state) {
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector&lt;memref::GlobalOp&gt; globalOps;
+  SmallVector&lt;memref::GetGlobalOp&gt; getGlobalOps;
+
+  // Get `builtin.module`.
+  auto moduleOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(moduleOps)) {
+    return emitDefiniteFailure()
+           &lt;&lt; Twine(&quot;expected exactly one &#x27;module&#x27; payload, but found &quot;) +
+                  std::to_string(llvm::range_size(moduleOps));
+  }
+  ModuleOp module = cast&lt;ModuleOp&gt;(*moduleOps.begin());
+
+  // Transform `memref.alloca`s.
+  for (auto *op : allocaOps) {
+    auto alloca = cast&lt;memref::AllocaOp&gt;(op);
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = alloca-&gt;getLoc();
+
+    memref::GlobalOp globalOp;
+    {
+      // Insert a `memref.global` at the beginning of the module.
+      if (module != alloca-&gt;getParentOfType&lt;ModuleOp&gt;()) {
+        return emitDefiniteFailure()
+               &lt;&lt; &quot;expected &#x27;alloca&#x27; payload to be inside &#x27;module&#x27; payload&quot;;
+      }
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&amp;module.getBodyRegion().front());
+      Type resultType = alloca.getResult().getType();
+      llvm::SmallString&lt;64&gt; symName = getUniqueSymbol(&quot;alloca&quot;, module);
+      // XXX: Add a better builder for this.
+      globalOp = rewriter.create&lt;memref::GlobalOp&gt;(
+          loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, &quot;private&quot;),
+          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+    }
+
+    // Replace the `memref.alloca` with a `memref.get_global` accessing the
+    // global symbol inserted above.
+    rewriter.setInsertionPoint(alloca);
+    auto getGlobalOp = rewriter.replaceOpWithNewOp&lt;memref::GetGlobalOp&gt;(
+        alloca, globalOp.getType(), globalOp.getName());
+
+    globalOps.push_back(globalOp);
+    getGlobalOps.push_back(getGlobalOp);
+  }
+
+  // Assemble results.
+  results.set(getGlobal().cast&lt;OpResult&gt;(), globalOps);
+  results.set(getGetGlobal().cast&lt;OpResult&gt;(), getGlobalOps);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MemRefAllocaToGlobalOp::getEffects(
+    SmallVectorImpl&lt;MemoryEffects::EffectInstance&gt; &amp;effects) {
+  onlyReadsHandle(getModule(), effects);
+  producesHandle(getGlobal(), effects);
+  producesHandle(getGetGlobal(), effects);
+  consumesHandle(getAlloca(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index de3cd1b28e435bc..f1d07b85adb7576 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &amp;rewriter,
   DenseSet&lt;Operation *&gt; resultSet;
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Operation *parent = target-&gt;getParentOp();
-    do {
+    while (parent) {
       bool checkIsolatedFromAbove =
           !getIsolatedFromAbove() ||
           parent-&gt;hasTrait&lt;OpTrait::IsIsolatedFromAbove&gt;();
@@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &amp;rewriter,
                          parent-&gt;getName().getStringRef() == *getOpName();
       if (checkIsolatedFromAbove &amp;&amp; checkOpName)
         break;
-    } while ((parent = parent-&gt;getParentOp()));
+      parent = parent-&gt;getParentOp();
+    }
     if (!parent) {
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 4afe8e7b887f68e..56dcfbe5655e9b6 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -11,6 +11,64 @@
 from typing import Optional, overload, Union
 
 
+class MemRefAllocaToGlobalOp:
+    &quot;&quot;&quot;Specialization for MemRefAllocaToGlobalOp class.&quot;&quot;&quot;
+
+    @overload
+    def __init__(
+        self,
+        get_global_type: Type,
+        global_type: Type,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        get_global_type_or_module: Union[Operation, OpView, Type, Value],
+        global_type_or_alloca: Union[Operation, OpView, Type, Value],
+        module_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(get_global_type_or_module, Type):
+            get_global_type = get_global_type_or_module
+            global_type = global_type_or_alloca
+            module = module_or_none
+            alloca = alloca_or_none
+        else:
+            get_global_type = transform.AnyOpType.get()
+            global_type = transform.AnyOpType.get()
+            module = get_global_type_or_module
+            alloca = global_type_or_alloca
+
+        super().__init__(
+            get_global_type,
+            global_type,
+            module,
+            alloca,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MemRefMultiBufferOp:
     &quot;&quot;&quot;Specialization for MemRefMultiBufferOp class.&quot;&quot;&quot;
 
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index b19db447af1c28a..aeeb2a6b0abedc5 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -1,5 +1,44 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
 
+// CHECK-DAG: memref.global &quot;private&quot; @[[ALLOC0:alloc.*]] : memref&lt;2x32xf32&gt;
+// CHECK-DAG: memref.global &quot;private&quot; @[[ALLOC1:alloc.*]] : memref&lt;2x32xf32&gt;
+
+// CHECK: func.func @func(
+func.func @func(%arg0: f32) {
+  %c3 = arith.constant 3 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: scf.forall
+  scf.forall (%arg1, %arg2) in (%c3, %c1) {
+    // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref&lt;2x32xf32&gt;
+    // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref&lt;2x32xf32&gt;
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref&lt;2x32xf32&gt;
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref&lt;2x32xf32&gt;
+    %alloca = memref.alloca() : memref&lt;2x32xf32&gt;
+    %alloca_0 = memref.alloca() : memref&lt;2x32xf32&gt;
+    memref.store %arg0, %alloca[%arg1, %arg2] : memref&lt;2x32xf32&gt;
+    memref.store %arg0, %alloca_0[%arg1, %arg2] : memref&lt;2x32xf32&gt;
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %alloca = transform.structured.match ops{[&quot;memref.alloca&quot;]} in %arg0
+      : (!transform.any_op) -&gt; !transform.any_op
+  %module = transform.structured.match ops{[&quot;builtin.module&quot;]} in %arg0
+      : (!transform.any_op) -&gt; !transform.any_op
+  %alloca_typed = transform.cast %alloca
+      : !transform.any_op to !transform.op&lt;&quot;memref.alloca&quot;&gt;
+  %module_typed = transform.cast %module
+      : !transform.any_op to !transform.op&lt;&quot;builtin.module&quot;&gt;
+  %get_global, %global =
+      transform.memref.alloca_to_global %alloca_typed in %module_typed
+        : (!transform.op&lt;&quot;builtin.module&quot;&gt;, !transform.op&lt;&quot;memref.alloca&quot;&gt;)
+          -&gt; (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map&lt;(d0) -&gt; ((d0 floordiv 4) mod 2)&gt;
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map&lt;(d0)[s0] -&gt; (d0 + s0)&gt;
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 68e3a4851539690..91a283c799941bb 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) {
   test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
 }
 
+
+// -----
+
+// expected-note @below {{target op}}
+module {
+  transform.sequence  failures(propagate) {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-error @below{{could not find a parent op that matches all requirements}}
+    %3 = get_parent_op %arg0 {op_name = &quot;builtin.module&quot;} : (!pdl.operation) -&gt; !transform.any_op
+  }
+}
+
 // -----
 
 func.func @cast(%arg0: f32) -&gt; f64 {
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f89005cb2f86d1b..8278019bbab3b89 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -16,6 +16,54 @@ def run(f):
     return f
 
 
+@run
+def testMemRefAllocaToAllocOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get(&quot;memref.alloc&quot;),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get(&quot;builtin.module&quot;), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get(&quot;memref.alloca&quot;), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(module, alloca)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: (!transform.op&lt;&quot;builtin.module&quot;&gt;, !transform.op&lt;&quot;memref.alloca&quot;&gt;)
+    # CHECK-SAME: -&gt; (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMemRefAllocaToAllocOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get(&quot;memref.alloc&quot;),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get(&quot;builtin.module&quot;), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get(&quot;memref.alloca&quot;), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(
+            transform.OperationType.get(&quot;memref.get_global&quot;),
+            transform.OperationType.get(&quot;memref.global&quot;),
+            module,
+            alloca,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: -&gt; (!transform.op&lt;&quot;memref.get_global&quot;&gt;, !transform.op&lt;&quot;memref.global&quot;&gt;)
+
+
 @run
 def testMemRefMultiBufferOpCompact():
     sequence = transform.SequenceOp(
</pre>
</details>


https://github.com/llvm/llvm-project/pull/66511


More information about the Mlir-commits mailing list