[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)

Ingo Müller llvmlistbot at llvm.org
Thu Sep 21 07:40:21 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/66511

>From 594b498f5e03c00ed27de42b1ccf57af781c0f22 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Wed, 13 Sep 2023 12:27:31 +0000
Subject: [PATCH 1/4] [mlir][memref][transform] Add new alloca_to_global op.

This PR adds a new transform op that replaces `memref.alloca`s with
`memref.get_global`s to newly inserted `memref.global`s. This is useful,
for example, for allocations that should reside in the shared memory of
a GPU, which have to be declared as globals.
---
 .../MemRef/TransformOps/MemRefTransformOps.td | 65 ++++++++++++++
 .../TransformOps/MemRefTransformOps.cpp       | 90 +++++++++++++++++++
 .../dialects/_memref_transform_ops_ext.py     | 58 ++++++++++++
 mlir/test/Dialect/MemRef/transform-ops.mlir   | 39 ++++++++
 .../python/dialects/transform_memref_ext.py   | 48 ++++++++++
 5 files changed, 300 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 681759f970cb910..6a78784d74dd53c 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
 }
 
 def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">;
+
+def MemRefAllocaToGlobalOp :
+  Op<Transform_Dialect, "memref.alloca_to_global",
+     [TransformOpInterface,
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let description = [{
+    Inserts a new `memref.global` for each provided `memref.alloca` into the
+    provided module and replaces it with a `memref.get_global`. This is useful,
+    for example, for allocations that should reside in the shared memory of
+    a GPU, which have to be declared as globals.
+
+    #### Example
+
+    Consider the following transform op:
+
+    ```mlir
+    %get_global, %global =
+        transform.memref.alloca_to_global %alloca in %module
+          : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+            -> (!transform.any_op, !transform.any_op)
+    ```
+
+    and the following input payload:
+
+    ```mlir
+    module {
+      func.func @func() {
+        %alloca = memref.alloca() : memref<2x32xf32>
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    then applying the transform op to the payload would result in the following
+    output IR:
+
+    ```mlir
+    module {
+      memref.global "private" @alloca : memref<2x32xf32>
+      func.func @func() {
+        %alloca = memref.get_global @alloca : memref<2x32xf32>
+        // usages of %alloca...
+      }
+    }
+    ```
+
+    #### Return modes
+
+    Emits a definite failure if not exactly one `module` payload op was provided
+    or any of the `alloca` payload ops is not inside that module, and succeeds
+    otherwise. The returned handles refer to the `memref.get_global` and
+    `memref.global` ops that were inserted by the transformation.
+  }];
+
+  let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module,
+                   Transform_MemRefAllocaOp:$alloca);
+  let results = (outs TransformHandleTypeInterface:$get_global,
+                  TransformHandleTypeInterface:$global);
+
+  let assemblyFormat = [{
+    $alloca `in` $module attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 58f4d8d8f6d21fe..7467359da83c37f 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
 }
 
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
+                                             ModuleOp module) {
+  llvm::SmallString<64> candidateNameStorage;
+  StringRef candidateName(prefix);
+  int uniqueNumber = 0;
+  while (true) {
+    if (!module.lookupSymbol(candidateName)) {
+      break;
+    }
+    candidateNameStorage.clear();
+    candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+                        .toStringRef(candidateNameStorage);
+    uniqueNumber++;
+  }
+  return candidateName;
+}
+} // namespace
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+                                         transform::TransformResults &results,
+                                         transform::TransformState &state) {
+  auto allocaOps = state.getPayloadOps(getAlloca());
+
+  SmallVector<memref::GlobalOp> globalOps;
+  SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+  // Get `builtin.module`.
+  auto moduleOps = state.getPayloadOps(getModule());
+  if (!llvm::hasSingleElement(moduleOps)) {
+    return emitDefiniteFailure()
+           << Twine("expected exactly one 'module' payload, but found ") +
+                  std::to_string(llvm::range_size(moduleOps));
+  }
+  ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+  // Transform `memref.alloca`s.
+  for (auto *op : allocaOps) {
+    auto alloca = cast<memref::AllocaOp>(op);
+    MLIRContext *ctx = rewriter.getContext();
+    Location loc = alloca->getLoc();
+
+    memref::GlobalOp globalOp;
+    {
+      // Insert a `memref.global` at the beginning of the module.
+      if (module != alloca->getParentOfType<ModuleOp>()) {
+        return emitDefiniteFailure()
+               << "expected 'alloca' payload to be inside 'module' payload";
+      }
+      IRRewriter::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+      Type resultType = alloca.getResult().getType();
+      llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+      // XXX: Add a better builder for this.
+      globalOp = rewriter.create<memref::GlobalOp>(
+          loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+    }
+
+    // Replace the `memref.alloca` with a `memref.get_global` accessing the
+    // global symbol inserted above.
+    rewriter.setInsertionPoint(alloca);
+    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+        alloca, globalOp.getType(), globalOp.getName());
+
+    globalOps.push_back(globalOp);
+    getGlobalOps.push_back(getGlobalOp);
+  }
+
+  // Assemble results.
+  results.set(getGlobal().cast<OpResult>(), globalOps);
+  results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MemRefAllocaToGlobalOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getModule(), effects);
+  producesHandle(getGlobal(), effects);
+  producesHandle(getGetGlobal(), effects);
+  consumesHandle(getAlloca(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefMultiBufferOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 4afe8e7b887f68e..56dcfbe5655e9b6 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -11,6 +11,64 @@
 from typing import Optional, overload, Union
 
 
+class MemRefAllocaToGlobalOp:
+    """Specialization for MemRefAllocaToGlobalOp class."""
+
+    @overload
+    def __init__(
+        self,
+        get_global_type: Type,
+        global_type: Type,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        module: Union[Operation, OpView, Value],
+        alloca: Union[Operation, OpView, Value],
+        *,
+        loc=None,
+        ip=None
+    ):
+        ...
+
+    def __init__(
+        self,
+        get_global_type_or_module: Union[Operation, OpView, Type, Value],
+        global_type_or_alloca: Union[Operation, OpView, Type, Value],
+        module_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        *,
+        loc=None,
+        ip=None
+    ):
+        if isinstance(get_global_type_or_module, Type):
+            get_global_type = get_global_type_or_module
+            global_type = global_type_or_alloca
+            module = module_or_none
+            alloca = alloca_or_none
+        else:
+            get_global_type = transform.AnyOpType.get()
+            global_type = transform.AnyOpType.get()
+            module = get_global_type_or_module
+            alloca = global_type_or_alloca
+
+        super().__init__(
+            get_global_type,
+            global_type,
+            module,
+            alloca,
+            loc=loc,
+            ip=ip,
+        )
+
+
 class MemRefMultiBufferOp:
     """Specialization for MemRefMultiBufferOp class."""
 
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index b19db447af1c28a..aeeb2a6b0abedc5 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -1,5 +1,44 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
 
+// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
+// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
+
+// CHECK: func.func @func(
+func.func @func(%arg0: f32) {
+  %c3 = arith.constant 3 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: scf.forall
+  scf.forall (%arg1, %arg2) in (%c3, %c1) {
+    // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
+    // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+    %alloca = memref.alloca() : memref<2x32xf32>
+    %alloca_0 = memref.alloca() : memref<2x32xf32>
+    memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32>
+    memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32>
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+  %module = transform.structured.match ops{["builtin.module"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+  %alloca_typed = transform.cast %alloca
+      : !transform.any_op to !transform.op<"memref.alloca">
+  %module_typed = transform.cast %module
+      : !transform.any_op to !transform.op<"builtin.module">
+  %get_global, %global =
+      transform.memref.alloca_to_global %alloca_typed in %module_typed
+        : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+          -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
 
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f89005cb2f86d1b..8278019bbab3b89 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -16,6 +16,54 @@ def run(f):
     return f
 
 
+@run
+def testMemRefAllocaToAllocOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get("builtin.module"), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(module, alloca)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMemRefAllocaToAllocOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        module = transform.CastOp(
+            transform.OperationType.get("builtin.module"), sequence.bodyTarget
+        )
+        alloca = transform.CastOp(
+            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+        )
+        memref.MemRefAllocaToGlobalOp(
+            transform.OperationType.get("memref.get_global"),
+            transform.OperationType.get("memref.global"),
+            module,
+            alloca,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
+    # CHECK: = transform.memref.alloca_to_global
+    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
+
+
 @run
 def testMemRefMultiBufferOpCompact():
     sequence = transform.SequenceOp(

>From b8f439e2fe4ba708beb906c9cac76f50845f94a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 09:21:58 +0000
Subject: [PATCH 2/4] Address comments from @ftynse's review.

In particular:
* Accept any op type with `SymbolTable` trait as containing op rather
  than only `builtin.module` and rename op argument accordingly.
* Use `SymbolTable::insert` to unique the name of the globals rather
  than some hand-rolled function.
* Use more sane semantics in Python mix-in test.
---
 .../MemRef/TransformOps/MemRefTransformOps.td | 22 ++++---
 .../TransformOps/MemRefTransformOps.cpp       | 63 ++++++++-----------
 mlir/test/Dialect/MemRef/transform-ops.mlir   | 19 +++---
 .../python/dialects/transform_memref_ext.py   | 20 ++----
 4 files changed, 54 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 6a78784d74dd53c..af2401a80b898b0 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -153,9 +153,10 @@ def MemRefAllocaToGlobalOp :
       DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let description = [{
     Inserts a new `memref.global` for each provided `memref.alloca` into the
-    provided module and replaces it with a `memref.get_global`. This is useful,
-    for example, for allocations that should reside in the shared memory of
-    a GPU, which have to be declared as globals.
+    provided symbol table (e.g., a `builtin.module`) and replaces it with a
+    `memref.get_global`. This is useful, for example, for allocations that
+    should reside in the shared memory of a GPU, which have to be declared as
+    globals.
 
     #### Example
 
@@ -194,19 +195,20 @@ def MemRefAllocaToGlobalOp :
 
     #### Return modes
 
-    Emits a definite failure if not exactly one `module` payload op was provided
-    or any of the `alloca` payload ops is not inside that module, and succeeds
-    otherwise. The returned handles refer to the `memref.get_global` and
-    `memref.global` ops that were inserted by the transformation.
+    Emits a definite failure if not exactly one symbol table payload op was
+    provided or any of the `alloca` payload ops is not inside that symbol table
+    op, and succeeds otherwise. The returned handles refer to the
+    `memref.get_global` and `memref.global` ops that were inserted by the
+    transformation.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module,
+  let arguments = (ins TransformHandleTypeInterface:$symbolTable,
                    Transform_MemRefAllocaOp:$alloca);
-  let results = (outs TransformHandleTypeInterface:$get_global,
+  let results = (outs TransformHandleTypeInterface:$getGlobal,
                   TransformHandleTypeInterface:$global);
 
   let assemblyFormat = [{
-    $alloca `in` $module attr-dict `:` functional-type(operands, results)
+    $alloca `in` $symbolTable attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7467359da83c37f..e5e19b4edbc5a85 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -130,25 +130,6 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
 // AllocaToGlobalOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
-                                             ModuleOp module) {
-  llvm::SmallString<64> candidateNameStorage;
-  StringRef candidateName(prefix);
-  int uniqueNumber = 0;
-  while (true) {
-    if (!module.lookupSymbol(candidateName)) {
-      break;
-    }
-    candidateNameStorage.clear();
-    candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
-                        .toStringRef(candidateNameStorage);
-    uniqueNumber++;
-  }
-  return candidateName;
-}
-} // namespace
-
 DiagnosedSilenceableFailure
 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
                                          transform::TransformResults &results,
@@ -158,14 +139,25 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<memref::GlobalOp> globalOps;
   SmallVector<memref::GetGlobalOp> getGlobalOps;
 
-  // Get `builtin.module`.
-  auto moduleOps = state.getPayloadOps(getModule());
-  if (!llvm::hasSingleElement(moduleOps)) {
+  // Get containing symbol table op.
+  auto symbolTableOps = state.getPayloadOps(getSymbolTable());
+  if (!llvm::hasSingleElement(symbolTableOps)) {
     return emitDefiniteFailure()
-           << Twine("expected exactly one 'module' payload, but found ") +
-                  std::to_string(llvm::range_size(moduleOps));
+           << Twine("expected exactly one 'symbolTable' payload, but found ") +
+                  std::to_string(llvm::range_size(symbolTableOps));
+  }
+  Operation *symbolTableOp = *symbolTableOps.begin();
+  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
+    return emitDefiniteFailure() << Twine(
+               "expected 'symbolTable' payload to have 'SymbolTable' trait");
+  }
+  SymbolTable symbolTable(symbolTableOp);
+
+  {
+    size_t numAllocaOps = llvm::range_size(allocaOps);
+    globalOps.reserve(numAllocaOps);
+    getGlobalOps.reserve(numAllocaOps);
   }
-  ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
 
   // Transform `memref.alloca`s.
   for (auto *op : allocaOps) {
@@ -175,19 +167,18 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
 
     memref::GlobalOp globalOp;
     {
-      // Insert a `memref.global` at the beginning of the module.
-      if (module != alloca->getParentOfType<ModuleOp>()) {
-        return emitDefiniteFailure()
-               << "expected 'alloca' payload to be inside 'module' payload";
+      // Insert a `memref.global` into the symbol table.
+      if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) {
+        return emitDefiniteFailure() << "expected 'alloca' payload to be "
+                                        "inside 'symbolTable' payload";
       }
-      IRRewriter::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
       Type resultType = alloca.getResult().getType();
-      llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
-      // XXX: Add a better builder for this.
-      globalOp = rewriter.create<memref::GlobalOp>(
-          loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+      // TODO: Add a better builder for this.
+      OpBuilder builder(rewriter.getContext());
+      globalOp = builder.create<memref::GlobalOp>(
+          loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
           TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+      symbolTable.insert(globalOp);
     }
 
     // Replace the `memref.alloca` with a `memref.get_global` accessing the
@@ -209,7 +200,7 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
 
 void transform::MemRefAllocaToGlobalOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  onlyReadsHandle(getModule(), effects);
+  onlyReadsHandle(getSymbolTable(), effects);
   producesHandle(getGlobal(), effects);
   producesHandle(getGetGlobal(), effects);
   consumesHandle(getAlloca(), effects);
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index aeeb2a6b0abedc5..e22e3d62190c445 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -3,20 +3,19 @@
 // CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
 // CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
 
-// CHECK: func.func @func(
-func.func @func(%arg0: f32) {
-  %c3 = arith.constant 3 : index
-  %c1 = arith.constant 1 : index
-  // CHECK: scf.forall
-  scf.forall (%arg1, %arg2) in (%c3, %c1) {
+// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index)
+func.func @func(%lb: index, %ub: index) {
+  // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]])
+  scf.forall (%arg0, %arg1) in (%lb, %ub) {
     // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
     // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
     // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
     // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
-    %alloca = memref.alloca() : memref<2x32xf32>
-    %alloca_0 = memref.alloca() : memref<2x32xf32>
-    memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32>
-    memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32>
+    %cst = arith.constant 0.0 : f32
+    %mr0 = memref.alloca() : memref<2x32xf32>
+    %mr1 = memref.alloca() : memref<2x32xf32>
+    memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32>
+    memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32>
   }
   return
 }
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index 8278019bbab3b89..6f622cbde70859e 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -21,15 +21,11 @@ def testMemRefAllocaToAllocOpCompact():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
-        transform.OperationType.get("memref.alloc"),
+        transform.OperationType.get("builtin.module"),
+        [transform.OperationType.get("memref.alloca")],
     )
     with InsertionPoint(sequence.body):
-        module = transform.CastOp(
-            transform.OperationType.get("builtin.module"), sequence.bodyTarget
-        )
-        alloca = transform.CastOp(
-            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
-        )
+        module, alloca = sequence.body.arguments
         memref.MemRefAllocaToGlobalOp(module, alloca)
         transform.YieldOp()
     # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
@@ -43,15 +39,11 @@ def testMemRefAllocaToAllocOpTyped():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
-        transform.OperationType.get("memref.alloc"),
+        transform.OperationType.get("builtin.module"),
+        [transform.OperationType.get("memref.alloca")],
     )
     with InsertionPoint(sequence.body):
-        module = transform.CastOp(
-            transform.OperationType.get("builtin.module"), sequence.bodyTarget
-        )
-        alloca = transform.CastOp(
-            transform.OperationType.get("memref.alloca"), sequence.bodyTarget
-        )
+        module, alloca = sequence.body.arguments
         memref.MemRefAllocaToGlobalOp(
             transform.OperationType.get("memref.get_global"),
             transform.OperationType.get("memref.global"),

>From a91c93ddcf4535405644cdc477229a5f982892ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 13:37:07 +0000
Subject: [PATCH 3/4] Remove symbolTable op arg and simplify tests.

---
 .../MemRef/TransformOps/MemRefTransformOps.td | 18 ++++-------
 .../TransformOps/MemRefTransformOps.cpp       | 32 ++++---------------
 mlir/test/Dialect/MemRef/transform-ops.mlir   | 13 ++------
 3 files changed, 16 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index af2401a80b898b0..d7bd8410e360a76 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -153,7 +153,7 @@ def MemRefAllocaToGlobalOp :
       DeclareOpInterfaceMethods<TransformOpInterface>]> {
   let description = [{
     Inserts a new `memref.global` for each provided `memref.alloca` into the
-    provided symbol table (e.g., a `builtin.module`) and replaces it with a
+    nearest symbol table (e.g., a `builtin.module`) and replaces it with a
     `memref.get_global`. This is useful, for example, for allocations that
     should reside in the shared memory of a GPU, which have to be declared as
     globals.
@@ -164,8 +164,8 @@ def MemRefAllocaToGlobalOp :
 
     ```mlir
     %get_global, %global =
-        transform.memref.alloca_to_global %alloca in %module
-          : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+        transform.memref.alloca_to_global %alloca
+          : (!transform.op<"memref.alloca">)
             -> (!transform.any_op, !transform.any_op)
     ```
 
@@ -195,20 +195,16 @@ def MemRefAllocaToGlobalOp :
 
     #### Return modes
 
-    Emits a definite failure if not exactly one symbol table payload op was
-    provided or any of the `alloca` payload ops is not inside that symbol table
-    op, and succeeds otherwise. The returned handles refer to the
-    `memref.get_global` and `memref.global` ops that were inserted by the
-    transformation.
+    Succeeds always. The returned handles refer to the `memref.get_global` and
+    `memref.global` ops that were inserted by the transformation.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$symbolTable,
-                   Transform_MemRefAllocaOp:$alloca);
+  let arguments = (ins Transform_MemRefAllocaOp:$alloca);
   let results = (outs TransformHandleTypeInterface:$getGlobal,
                   TransformHandleTypeInterface:$global);
 
   let assemblyFormat = [{
-    $alloca `in` $symbolTable attr-dict `:` functional-type(operands, results)
+    $alloca attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index e5e19b4edbc5a85..eed29efcaaada88 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -139,26 +139,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<memref::GlobalOp> globalOps;
   SmallVector<memref::GetGlobalOp> getGlobalOps;
 
-  // Get containing symbol table op.
-  auto symbolTableOps = state.getPayloadOps(getSymbolTable());
-  if (!llvm::hasSingleElement(symbolTableOps)) {
-    return emitDefiniteFailure()
-           << Twine("expected exactly one 'symbolTable' payload, but found ") +
-                  std::to_string(llvm::range_size(symbolTableOps));
-  }
-  Operation *symbolTableOp = *symbolTableOps.begin();
-  if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
-    return emitDefiniteFailure() << Twine(
-               "expected 'symbolTable' payload to have 'SymbolTable' trait");
-  }
-  SymbolTable symbolTable(symbolTableOp);
-
-  {
-    size_t numAllocaOps = llvm::range_size(allocaOps);
-    globalOps.reserve(numAllocaOps);
-    getGlobalOps.reserve(numAllocaOps);
-  }
-
   // Transform `memref.alloca`s.
   for (auto *op : allocaOps) {
     auto alloca = cast<memref::AllocaOp>(op);
@@ -167,14 +147,15 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
 
     memref::GlobalOp globalOp;
     {
+      // Find nearest symbol table.
+      Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+      assert(symbolTableOp && "expected alloca payload to be in symbol table");
+      SymbolTable symbolTable(symbolTableOp);
+
       // Insert a `memref.global` into the symbol table.
-      if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) {
-        return emitDefiniteFailure() << "expected 'alloca' payload to be "
-                                        "inside 'symbolTable' payload";
-      }
       Type resultType = alloca.getResult().getType();
-      // TODO: Add a better builder for this.
       OpBuilder builder(rewriter.getContext());
+      // TODO: Add a better builder for this.
       globalOp = builder.create<memref::GlobalOp>(
           loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
           TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
@@ -200,7 +181,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
 
 void transform::MemRefAllocaToGlobalOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  onlyReadsHandle(getSymbolTable(), effects);
   producesHandle(getGlobal(), effects);
   producesHandle(getGetGlobal(), effects);
   consumesHandle(getAlloca(), effects);
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index e22e3d62190c445..68fea1f8402955c 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -23,16 +23,9 @@ func.func @func(%lb: index, %ub: index) {
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !transform.any_op):
   %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
-      : (!transform.any_op) -> !transform.any_op
-  %module = transform.structured.match ops{["builtin.module"]} in %arg0
-      : (!transform.any_op) -> !transform.any_op
-  %alloca_typed = transform.cast %alloca
-      : !transform.any_op to !transform.op<"memref.alloca">
-  %module_typed = transform.cast %module
-      : !transform.any_op to !transform.op<"builtin.module">
-  %get_global, %global =
-      transform.memref.alloca_to_global %alloca_typed in %module_typed
-        : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+      : (!transform.any_op) -> !transform.op<"memref.alloca">
+  %get_global, %global = transform.memref.alloca_to_global %alloca
+        : (!transform.op<"memref.alloca">)
           -> (!transform.any_op, !transform.any_op)
 }
 

>From 05598559cf8e2c4a83ab77f74d84b9de3e9edfa6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 14:40:07 +0000
Subject: [PATCH 4/4] Remove symbolTable op arg in Python mix-in.

---
 .../dialects/_memref_transform_ops_ext.py     | 26 +++++--------------
 .../python/dialects/transform_memref_ext.py   | 15 ++++-------
 2 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 56dcfbe5655e9b6..1cc00bdcbf381c9 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -19,7 +19,6 @@ def __init__(
         self,
         get_global_type: Type,
         global_type: Type,
-        module: Union[Operation, OpView, Value],
         alloca: Union[Operation, OpView, Value],
         *,
         loc=None,
@@ -28,41 +27,30 @@ def __init__(
         ...
 
     @overload
-    def __init__(
-        self,
-        module: Union[Operation, OpView, Value],
-        alloca: Union[Operation, OpView, Value],
-        *,
-        loc=None,
-        ip=None
-    ):
+    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
         ...
 
     def __init__(
         self,
-        get_global_type_or_module: Union[Operation, OpView, Type, Value],
-        global_type_or_alloca: Union[Operation, OpView, Type, Value],
-        module_or_none: Optional[Union[Operation, OpView, Value]] = None,
+        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
+        global_type_or_none: Optional[Type] = None,
         alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
         *,
         loc=None,
         ip=None
     ):
-        if isinstance(get_global_type_or_module, Type):
-            get_global_type = get_global_type_or_module
-            global_type = global_type_or_alloca
-            module = module_or_none
+        if isinstance(get_global_type_or_alloca, Type):
+            get_global_type = get_global_type_or_alloca
+            global_type = global_type_or_none
             alloca = alloca_or_none
         else:
             get_global_type = transform.AnyOpType.get()
             global_type = transform.AnyOpType.get()
-            module = get_global_type_or_module
-            alloca = global_type_or_alloca
+            alloca = get_global_type_or_alloca
 
         super().__init__(
             get_global_type,
             global_type,
-            module,
             alloca,
             loc=loc,
             ip=ip,
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index 6f622cbde70859e..e7d871c9eac8c7f 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -21,16 +21,14 @@ def testMemRefAllocaToAllocOpCompact():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
-        transform.OperationType.get("builtin.module"),
-        [transform.OperationType.get("memref.alloca")],
+        transform.OperationType.get("memref.alloca"),
     )
     with InsertionPoint(sequence.body):
-        module, alloca = sequence.body.arguments
-        memref.MemRefAllocaToGlobalOp(module, alloca)
+        memref.MemRefAllocaToGlobalOp(sequence.bodyTarget)
         transform.YieldOp()
     # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
     # CHECK: = transform.memref.alloca_to_global
-    # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+    # CHECK-SAME: (!transform.op<"memref.alloca">)
     # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
 
 
@@ -39,16 +37,13 @@ def testMemRefAllocaToAllocOpTyped():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
-        transform.OperationType.get("builtin.module"),
-        [transform.OperationType.get("memref.alloca")],
+        transform.OperationType.get("memref.alloca"),
     )
     with InsertionPoint(sequence.body):
-        module, alloca = sequence.body.arguments
         memref.MemRefAllocaToGlobalOp(
             transform.OperationType.get("memref.get_global"),
             transform.OperationType.get("memref.global"),
-            module,
-            alloca,
+            sequence.bodyTarget,
         )
         transform.YieldOp()
     # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped



More information about the Mlir-commits mailing list