[llvm-branch-commits] [mlir] b2391d5 - [MLIR] Normalize the results of normalizable operations
Uday Bondhugula via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 3 06:09:33 PST 2020
Author: Haruki Imai
Date: 2020-12-03T19:34:07+05:30
New Revision: b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4
URL: https://github.com/llvm/llvm-project/commit/b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4
DIFF: https://github.com/llvm/llvm-project/commit/b2391d5f0da2a9122d337de6fa1a6aa9baad3fb4.diff
LOG: [MLIR] Normalize the results of normalizable operations
Memrefs with affine maps in the results of normalizable operations were
not normalized by the `--normalize-memrefs` option. This patch normalizes
them.
Differential Revision: https://reviews.llvm.org/D88719
Added:
Modified:
mlir/lib/Transforms/NormalizeMemRefs.cpp
mlir/test/Transforms/normalize-memrefs-ops.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index 44b3ccbd2c3f..d7fa212baa73 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -36,6 +36,7 @@ struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
DenseSet<FuncOp> &normalizableFuncs);
+ Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
};
} // end anonymous namespace
@@ -384,6 +385,59 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
funcOp.front().eraseArgument(argIndex + 1);
}
+ // Walk over normalizable operations to normalize memrefs of the operation
+ // results. When `op` has memrefs with affine map in the operation results,
+ // new operation containing normalized memrefs is created. Then, the memrefs
+ // are replaced. `CallOp` is skipped here because it is handled in
+ // `updateFunctionSignature()`.
+ funcOp.walk([&](Operation *op) {
+ if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
+ op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
+ // Create newOp containing normalized memref in the operation result.
+ Operation *newOp = createOpResultsNormalized(funcOp, op);
+ // When all of the operation results have no memrefs or memrefs without
+ // affine map, `newOp` is the same as `op` and the following process is
+ // skipped.
+ if (op != newOp) {
+ bool replacingMemRefUsesFailed = false;
+ for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
+ // Replace all uses of the old memrefs.
+ Value oldMemRef = op->getResult(resIndex);
+ Value newMemRef = newOp->getResult(resIndex);
+ MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+ // Check whether the operation result is MemRef type.
+ if (!oldMemRefType)
+ continue;
+ MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+ if (oldMemRefType == newMemRefType)
+ continue;
+ // TODO: Assume single layout map. Multiple maps not supported.
+ AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+ if (failed(replaceAllMemRefUsesWith(oldMemRef,
+ /*newMemRef=*/newMemRef,
+ /*extraIndices=*/{},
+ /*indexRemap=*/layoutMap,
+ /*extraOperands=*/{},
+ /*symbolOperands=*/{},
+ /*domInstFilter=*/nullptr,
+ /*postDomInstFilter=*/nullptr,
+ /*allowDereferencingOps=*/true,
+ /*replaceInDeallocOp=*/true))) {
+ newOp->erase();
+ replacingMemRefUsesFailed = true;
+ continue;
+ }
+ }
+ if (!replacingMemRefUsesFailed) {
+ // Replace other ops with new op and delete the old op when the
+ // replacement succeeded.
+ op->replaceAllUsesWith(newOp);
+ op->erase();
+ }
+ }
+ }
+ });
+
// In a normal function, memrefs in the return type signature gets normalized
// as a result of normalization of functions arguments, AllocOps or CallOps'
// result types. Since an external function doesn't have a body, memrefs in
@@ -417,3 +471,49 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
}
updateFunctionSignature(funcOp, moduleOp);
}
+
+/// Build a copy of `oldOp` whose memref results have normalized types.
+/// The copy reuses `oldOp`'s location, name, operands and attributes; only
+/// result types that `normalizeMemRefType` changes are replaced. The new op is
+/// created via an `OpBuilder` constructed from `oldOp`. If no result type
+/// changes, nothing is created and `oldOp` itself is returned.
+Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
+ Operation *oldOp) {
+ // Seed the new op's state from `oldOp`. NOTE(review): only operands and
+ // attributes are copied here; regions/successors are not -- confirm intended.
+ OperationState result(oldOp->getLoc(), oldOp->getName());
+ result.addOperands(oldOp->getOperands());
+ result.addAttributes(oldOp->getAttrs());
+ // Collect the result types, normalizing memref types where possible.
+ SmallVector<Type, 4> resultTypes;
+ OpBuilder b(funcOp);
+ bool resultTypeNormalized = false;
+ for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
+ auto resultType = oldOp->getResult(resIndex).getType();
+ MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+ // Non-memref results keep their type unchanged.
+ if (!memrefType) {
+ resultTypes.push_back(resultType);
+ continue;
+ }
+ // Ask for the normalized form of this memref type (no symbolic operands).
+ MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+ /*numSymbolicOperands=*/0);
+ if (newMemRefType == memrefType) {
+ // normalizeMemRefType returned the type unchanged, so there is nothing
+ // to rewrite for this result; keep the original type.
+ resultTypes.push_back(memrefType);
+ continue;
+ }
+ resultTypes.push_back(newMemRefType);
+ resultTypeNormalized = true;
+ }
+ result.addTypes(resultTypes);
+ // Only materialize a new op (through a builder constructed from `oldOp`)
+ // when at least one result type changed; otherwise hand back `oldOp`.
+ if (resultTypeNormalized) {
+ OpBuilder bb(oldOp);
+ return bb.createOperation(result);
+ } else
+ return oldOp;
+}
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
index 92b23c07887f..9567ab5f8f83 100644
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -14,13 +14,13 @@
// Test with op_norm and maps in arguments and in the operations in the function.
// CHECK-LABEL: test_norm
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
- // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+ // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32>
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
return
@@ -29,13 +29,13 @@ func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
// Same test with op_nonnorm, with maps in the arguments and the operations in the function.
// CHECK-LABEL: test_nonnorm
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map>)
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #map>)
func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
- // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map>
+ // CHECK: %[[v0:.*]] = alloc() : memref<1x16x14x14xf32, #map>
// CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map>, memref<1x16x14x14xf32, #map>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map>
return
@@ -44,13 +44,13 @@ func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
// Test with op_norm, with maps in the operations in the function.
// CHECK-LABEL: test_norm_mix
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>
func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map0>
"test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
dealloc %0 : memref<1x16x14x14xf32, #map0>
- // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+ // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32>
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
return
@@ -61,12 +61,12 @@ func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
// CHECK-LABEL: test_load_store
-// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32>
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32>
func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
%0 = alloc() : memref<1x16x14x14xf32, #map_tile>
- // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+ // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32>
%1 = alloc() : memref<1x16x14x14xf32>
- // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32>
+ // CHECK: %[[v1:.*]] = alloc() : memref<1x16x14x14xf32>
"test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
// CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
%cst = constant 3.0 : f32
@@ -90,6 +90,25 @@ func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
return
}
+// Test with op_norm_ret, with maps in the results of normalizable operation.
+
+// CHECK-LABEL: test_norm_ret
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
+func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
+ %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
+ // CHECK-NEXT: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32>
+ %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
+ // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
+ // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
+ "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
+ // CHECK-NEXT: "test.op_norm"
+ // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
+ dealloc %0 : memref<1x16x14x14xf32, #map_tile>
+ // CHECK-NEXT: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+ return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
+ // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
+}
+
// Test with an arbitrary op that references the function symbol.
"test.op_funcref"() {func = @test_norm_mix} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 995022ab2115..7d9274fbddee 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -666,6 +666,11 @@ def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
def OpNonNorm : TEST_Op<"op_nonnorm"> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
}
+// Test for memrefs normalization of an op that has normalizable memref results.
+def OpNormRet : TEST_Op<"op_norm_ret", [MemRefsNormalizable]> {
+ let arguments = (ins AnyMemRef:$X);
+ let results = (outs AnyMemRef:$Y, AnyMemRef:$Z);
+}
// Test for memrefs normalization of an op with a reference to a function
// symbol.
More information about the llvm-branch-commits
mailing list