[Mlir-commits] [mlir] [mlir][memref][transform] Add new alloca_to_global op. (PR #66511)
Ingo Müller
llvmlistbot at llvm.org
Thu Sep 21 03:13:45 PDT 2023
https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/66511
>From 325bc7827d04ca76069aed08133b1497877642eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Wed, 13 Sep 2023 12:27:31 +0000
Subject: [PATCH 1/2] [mlir][memref][transform] Add new alloca_to_global op.
This PR adds a new transform op that replaces `memref.alloca`s with
`memref.get_global`s to newly inserted `memref.global`s. This is useful,
for example, for allocations that should reside in the shared memory of
a GPU, which have to be declared as globals.
---
.../MemRef/TransformOps/MemRefTransformOps.td | 65 ++++++++++++++
.../TransformOps/MemRefTransformOps.cpp | 90 +++++++++++++++++++
.../dialects/_memref_transform_ops_ext.py | 58 ++++++++++++
mlir/test/Dialect/MemRef/transform-ops.mlir | 39 ++++++++
.../python/dialects/transform_memref_ext.py | 48 ++++++++++
5 files changed, 300 insertions(+)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 681759f970cb910..6a78784d74dd53c 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect,
}
def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">;
+
+def MemRefAllocaToGlobalOp :
+ Op<Transform_Dialect, "memref.alloca_to_global",
+ [TransformOpInterface,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let description = [{
+ Inserts a new `memref.global` for each provided `memref.alloca` into the
+ provided module and replaces it with a `memref.get_global`. This is useful,
+ for example, for allocations that should reside in the shared memory of
+ a GPU, which have to be declared as globals.
+
+ #### Example
+
+ Consider the following transform op:
+
+ ```mlir
+ %get_global, %global =
+ transform.memref.alloca_to_global %alloca in %module
+ : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+ -> (!transform.any_op, !transform.any_op)
+ ```
+
+ and the following input payload:
+
+ ```mlir
+ module {
+ func.func @func() {
+ %alloca = memref.alloca() : memref<2x32xf32>
+ // usages of %alloca...
+ }
+ }
+ ```
+
+ then applying the transform op to the payload would result in the following
+ output IR:
+
+ ```mlir
+ module {
+ memref.global "private" @alloc : memref<2x32xf32>
+ func.func @func() {
+ %alloca = memref.get_global @alloc : memref<2x32xf32>
+ // usages of %alloca...
+ }
+ }
+ ```
+
+ #### Return modes
+
+ Emits a definite failure if not exactly one `module` payload op was provided
+ or any of the `alloca` payload ops is not inside that module, and succeeds
+ otherwise. The returned handles refer to the `memref.get_global` and
+ `memref.global` ops that were inserted by the transformation.
+ }];
+
+ let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module,
+ Transform_MemRefAllocaOp:$alloca);
+ let results = (outs TransformHandleTypeInterface:$get_global,
+ TransformHandleTypeInterface:$global);
+
+ let assemblyFormat = [{
+ $alloca `in` $module attr-dict `:` functional-type(operands, results)
+ }];
+}
def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 58f4d8d8f6d21fe..7467359da83c37f 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
+//===----------------------------------------------------------------------===//
+// AllocaToGlobalOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
+ ModuleOp module) {
+ llvm::SmallString<64> candidateNameStorage;
+ StringRef candidateName(prefix);
+ int uniqueNumber = 0;
+ while (true) {
+ if (!module.lookupSymbol(candidateName)) {
+ break;
+ }
+ candidateNameStorage.clear();
+ candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
+ .toStringRef(candidateNameStorage);
+ uniqueNumber++;
+ }
+ return candidateName;
+}
+} // namespace
+
+DiagnosedSilenceableFailure
+transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto allocaOps = state.getPayloadOps(getAlloca());
+
+ SmallVector<memref::GlobalOp> globalOps;
+ SmallVector<memref::GetGlobalOp> getGlobalOps;
+
+ // Get `builtin.module`.
+ auto moduleOps = state.getPayloadOps(getModule());
+ if (!llvm::hasSingleElement(moduleOps)) {
+ return emitDefiniteFailure()
+ << Twine("expected exactly one 'module' payload, but found ") +
+ std::to_string(llvm::range_size(moduleOps));
+ }
+ ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
+
+ // Transform `memref.alloca`s.
+ for (auto *op : allocaOps) {
+ auto alloca = cast<memref::AllocaOp>(op);
+ MLIRContext *ctx = rewriter.getContext();
+ Location loc = alloca->getLoc();
+
+ memref::GlobalOp globalOp;
+ {
+ // Insert a `memref.global` at the beginning of the module.
+ if (module != alloca->getParentOfType<ModuleOp>()) {
+ return emitDefiniteFailure()
+ << "expected 'alloca' payload to be inside 'module' payload";
+ }
+ IRRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
+ Type resultType = alloca.getResult().getType();
+ llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
+ // XXX: Add a better builder for this.
+ globalOp = rewriter.create<memref::GlobalOp>(
+ loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+ TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+ }
+
+ // Replace the `memref.alloca` with a `memref.get_global` accessing the
+ // global symbol inserted above.
+ rewriter.setInsertionPoint(alloca);
+ auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
+ alloca, globalOp.getType(), globalOp.getName());
+
+ globalOps.push_back(globalOp);
+ getGlobalOps.push_back(getGlobalOp);
+ }
+
+ // Assemble results.
+ results.set(getGlobal().cast<OpResult>(), globalOps);
+ results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MemRefAllocaToGlobalOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getModule(), effects);
+ producesHandle(getGlobal(), effects);
+ producesHandle(getGetGlobal(), effects);
+ consumesHandle(getAlloca(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
index 4afe8e7b887f68e..56dcfbe5655e9b6 100644
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
@@ -11,6 +11,64 @@
from typing import Optional, overload, Union
+class MemRefAllocaToGlobalOp:
+ """Specialization for MemRefAllocaToGlobalOp class."""
+
+ @overload
+ def __init__(
+ self,
+ get_global_type: Type,
+ global_type: Type,
+ module: Union[Operation, OpView, Value],
+ alloca: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ module: Union[Operation, OpView, Value],
+ alloca: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None
+ ):
+ ...
+
+ def __init__(
+ self,
+ get_global_type_or_module: Union[Operation, OpView, Type, Value],
+ global_type_or_alloca: Union[Operation, OpView, Type, Value],
+ module_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None
+ ):
+ if isinstance(get_global_type_or_module, Type):
+ get_global_type = get_global_type_or_module
+ global_type = global_type_or_alloca
+ module = module_or_none
+ alloca = alloca_or_none
+ else:
+ get_global_type = transform.AnyOpType.get()
+ global_type = transform.AnyOpType.get()
+ module = get_global_type_or_module
+ alloca = global_type_or_alloca
+
+ super().__init__(
+ get_global_type,
+ global_type,
+ module,
+ alloca,
+ loc=loc,
+ ip=ip,
+ )
+
+
class MemRefMultiBufferOp:
"""Specialization for MemRefMultiBufferOp class."""
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index b19db447af1c28a..aeeb2a6b0abedc5 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -1,5 +1,44 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
+// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
+// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
+
+// CHECK: func.func @func(
+func.func @func(%arg0: f32) {
+ %c3 = arith.constant 3 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: scf.forall
+ scf.forall (%arg1, %arg2) in (%c3, %c1) {
+ // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
+ // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
+ // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+ // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
+ %alloca = memref.alloca() : memref<2x32xf32>
+ %alloca_0 = memref.alloca() : memref<2x32xf32>
+ memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32>
+ memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32>
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+ %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %module = transform.structured.match ops{["builtin.module"]} in %arg0
+ : (!transform.any_op) -> !transform.any_op
+ %alloca_typed = transform.cast %alloca
+ : !transform.any_op to !transform.op<"memref.alloca">
+ %module_typed = transform.cast %module
+ : !transform.any_op to !transform.op<"builtin.module">
+ %get_global, %global =
+ transform.memref.alloca_to_global %alloca_typed in %module_typed
+ : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+ -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index f89005cb2f86d1b..8278019bbab3b89 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -16,6 +16,54 @@ def run(f):
return f
+@run
+def testMemRefAllocaToAllocOpCompact():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("memref.alloc"),
+ )
+ with InsertionPoint(sequence.body):
+ module = transform.CastOp(
+ transform.OperationType.get("builtin.module"), sequence.bodyTarget
+ )
+ alloca = transform.CastOp(
+ transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+ )
+ memref.MemRefAllocaToGlobalOp(module, alloca)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
+ # CHECK: = transform.memref.alloca_to_global
+ # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
+ # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
+
+
+@run
+def testMemRefAllocaToAllocOpTyped():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("memref.alloc"),
+ )
+ with InsertionPoint(sequence.body):
+ module = transform.CastOp(
+ transform.OperationType.get("builtin.module"), sequence.bodyTarget
+ )
+ alloca = transform.CastOp(
+ transform.OperationType.get("memref.alloca"), sequence.bodyTarget
+ )
+ memref.MemRefAllocaToGlobalOp(
+ transform.OperationType.get("memref.get_global"),
+ transform.OperationType.get("memref.global"),
+ module,
+ alloca,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
+ # CHECK: = transform.memref.alloca_to_global
+ # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
+
+
@run
def testMemRefMultiBufferOpCompact():
sequence = transform.SequenceOp(
>From 8ee31ebe3639ec04159cf682fec1bb9ac228da3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 21 Sep 2023 09:21:58 +0000
Subject: [PATCH 2/2] Address comments from @ftynse's review.
In particular:
* Accept any op type with `SymbolTable` trait as containing op rather
than only `builtin.module` and rename op argument accordingly.
* Use `SymbolTable::insert` to unique the name of the globals rather
than some hand-rolled function.
* Use more sane semantics in Python mix-in test.
---
.../MemRef/TransformOps/MemRefTransformOps.td | 22 ++++---
.../TransformOps/MemRefTransformOps.cpp | 63 ++++++++-----------
mlir/test/Dialect/MemRef/transform-ops.mlir | 19 +++---
.../python/dialects/transform_memref_ext.py | 20 ++----
4 files changed, 54 insertions(+), 70 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index 6a78784d74dd53c..af2401a80b898b0 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -153,9 +153,10 @@ def MemRefAllocaToGlobalOp :
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let description = [{
Inserts a new `memref.global` for each provided `memref.alloca` into the
- provided module and replaces it with a `memref.get_global`. This is useful,
- for example, for allocations that should reside in the shared memory of
- a GPU, which have to be declared as globals.
+ provided symbol table (e.g., a `builtin.module`) and replaces it with a
+ `memref.get_global`. This is useful, for example, for allocations that
+ should reside in the shared memory of a GPU, which have to be declared as
+ globals.
#### Example
@@ -194,19 +195,20 @@ def MemRefAllocaToGlobalOp :
#### Return modes
- Emits a definite failure if not exactly one `module` payload op was provided
- or any of the `alloca` payload ops is not inside that module, and succeeds
- otherwise. The returned handles refer to the `memref.get_global` and
- `memref.global` ops that were inserted by the transformation.
+ Emits a definite failure if not exactly one symbol table payload op was
+ provided or any of the `alloca` payload ops is not inside that symbol table
+ op, and succeeds otherwise. The returned handles refer to the
+ `memref.get_global` and `memref.global` ops that were inserted by the
+ transformation.
}];
- let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module,
+ let arguments = (ins TransformHandleTypeInterface:$symbolTable,
Transform_MemRefAllocaOp:$alloca);
- let results = (outs TransformHandleTypeInterface:$get_global,
+ let results = (outs TransformHandleTypeInterface:$getGlobal,
TransformHandleTypeInterface:$global);
let assemblyFormat = [{
- $alloca `in` $module attr-dict `:` functional-type(operands, results)
+ $alloca `in` $symbolTable attr-dict `:` functional-type(operands, results)
}];
}
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7467359da83c37f..e5e19b4edbc5a85 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -130,25 +130,6 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
// AllocaToGlobalOp
//===----------------------------------------------------------------------===//
-namespace {
-static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix,
- ModuleOp module) {
- llvm::SmallString<64> candidateNameStorage;
- StringRef candidateName(prefix);
- int uniqueNumber = 0;
- while (true) {
- if (!module.lookupSymbol(candidateName)) {
- break;
- }
- candidateNameStorage.clear();
- candidateName = (prefix + Twine("_") + Twine(uniqueNumber))
- .toStringRef(candidateNameStorage);
- uniqueNumber++;
- }
- return candidateName;
-}
-} // namespace
-
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -158,14 +139,25 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
SmallVector<memref::GlobalOp> globalOps;
SmallVector<memref::GetGlobalOp> getGlobalOps;
- // Get `builtin.module`.
- auto moduleOps = state.getPayloadOps(getModule());
- if (!llvm::hasSingleElement(moduleOps)) {
+ // Get containing symbol table op.
+ auto symbolTableOps = state.getPayloadOps(getSymbolTable());
+ if (!llvm::hasSingleElement(symbolTableOps)) {
return emitDefiniteFailure()
- << Twine("expected exactly one 'module' payload, but found ") +
- std::to_string(llvm::range_size(moduleOps));
+ << Twine("expected exactly one 'symbolTable' payload, but found ") +
+ std::to_string(llvm::range_size(symbolTableOps));
+ }
+ Operation *symbolTableOp = *symbolTableOps.begin();
+ if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
+ return emitDefiniteFailure() << Twine(
+ "expected 'symbolTable' payload to have 'SymbolTable' trait");
+ }
+ SymbolTable symbolTable(symbolTableOp);
+
+ {
+ size_t numAllocaOps = llvm::range_size(allocaOps);
+ globalOps.reserve(numAllocaOps);
+ getGlobalOps.reserve(numAllocaOps);
}
- ModuleOp module = cast<ModuleOp>(*moduleOps.begin());
// Transform `memref.alloca`s.
for (auto *op : allocaOps) {
@@ -175,19 +167,18 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
memref::GlobalOp globalOp;
{
- // Insert a `memref.global` at the beginning of the module.
- if (module != alloca->getParentOfType<ModuleOp>()) {
- return emitDefiniteFailure()
- << "expected 'alloca' payload to be inside 'module' payload";
+ // Insert a `memref.global` into the symbol table.
+ if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) {
+ return emitDefiniteFailure() << "expected 'alloca' payload to be "
+ "inside 'symbolTable' payload";
}
- IRRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&module.getBodyRegion().front());
Type resultType = alloca.getResult().getType();
- llvm::SmallString<64> symName = getUniqueSymbol("alloca", module);
- // XXX: Add a better builder for this.
- globalOp = rewriter.create<memref::GlobalOp>(
- loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"),
+ // TODO: Add a better builder for this.
+ OpBuilder builder(rewriter.getContext());
+ globalOp = builder.create<memref::GlobalOp>(
+ loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
+ symbolTable.insert(globalOp);
}
// Replace the `memref.alloca` with a `memref.get_global` accessing the
@@ -209,7 +200,7 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
void transform::MemRefAllocaToGlobalOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- onlyReadsHandle(getModule(), effects);
+ onlyReadsHandle(getSymbolTable(), effects);
producesHandle(getGlobal(), effects);
producesHandle(getGetGlobal(), effects);
consumesHandle(getAlloca(), effects);
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
index aeeb2a6b0abedc5..e22e3d62190c445 100644
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -3,20 +3,19 @@
// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32>
// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32>
-// CHECK: func.func @func(
-func.func @func(%arg0: f32) {
- %c3 = arith.constant 3 : index
- %c1 = arith.constant 1 : index
- // CHECK: scf.forall
- scf.forall (%arg1, %arg2) in (%c3, %c1) {
+// CHECK-DAG: func.func @func(%[[LB:.*]]: index, %[[UB:.*]]: index)
+func.func @func(%lb: index, %ub: index) {
+ // CHECK-DAG: scf.forall (%[[ARG0:.*]], %[[ARG1:.*]]) in (%[[LB]], %[[UB]])
+ scf.forall (%arg0, %arg1) in (%lb, %ub) {
// CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32>
// CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32>
// CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
// CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32>
- %alloca = memref.alloca() : memref<2x32xf32>
- %alloca_0 = memref.alloca() : memref<2x32xf32>
- memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32>
- memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32>
+ %cst = arith.constant 0.0 : f32
+ %mr0 = memref.alloca() : memref<2x32xf32>
+ %mr1 = memref.alloca() : memref<2x32xf32>
+ memref.store %cst, %mr0[%arg0, %arg1] : memref<2x32xf32>
+ memref.store %cst, %mr1[%arg0, %arg1] : memref<2x32xf32>
}
return
}
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
index 8278019bbab3b89..6f622cbde70859e 100644
--- a/mlir/test/python/dialects/transform_memref_ext.py
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -21,15 +21,11 @@ def testMemRefAllocaToAllocOpCompact():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("memref.alloc"),
+ transform.OperationType.get("builtin.module"),
+ [transform.OperationType.get("memref.alloca")],
)
with InsertionPoint(sequence.body):
- module = transform.CastOp(
- transform.OperationType.get("builtin.module"), sequence.bodyTarget
- )
- alloca = transform.CastOp(
- transform.OperationType.get("memref.alloca"), sequence.bodyTarget
- )
+ module, alloca = sequence.body.arguments
memref.MemRefAllocaToGlobalOp(module, alloca)
transform.YieldOp()
# CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
@@ -43,15 +39,11 @@ def testMemRefAllocaToAllocOpTyped():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
- transform.OperationType.get("memref.alloc"),
+ transform.OperationType.get("builtin.module"),
+ [transform.OperationType.get("memref.alloca")],
)
with InsertionPoint(sequence.body):
- module = transform.CastOp(
- transform.OperationType.get("builtin.module"), sequence.bodyTarget
- )
- alloca = transform.CastOp(
- transform.OperationType.get("memref.alloca"), sequence.bodyTarget
- )
+ module, alloca = sequence.body.arguments
memref.MemRefAllocaToGlobalOp(
transform.OperationType.get("memref.get_global"),
transform.OperationType.get("memref.global"),
More information about the Mlir-commits
mailing list