[Mlir-commits] [mlir] [mlir][transform] Implement `FlattenElementwiseLinalgOp` transform op (PR #81431)
llvmlistbot at llvm.org
Sun Feb 11 13:00:31 PST 2024
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/81431
None
>From 6e05d6a3ed218797ae264fc88f8998a0a4b945dc Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 11 Feb 2024 02:33:16 -0600
Subject: [PATCH 1/2] Implement FlattenElementwiseLinalgOp transform
---
.../Linalg/TransformOps/LinalgTransformOps.td | 42 +++++++++
.../TransformOps/LinalgTransformOps.cpp | 87 +++++++++++++++++++
2 files changed, 129 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872f..d8d864d14ea698 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2295,6 +2295,48 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp
+//===----------------------------------------------------------------------===//
+
+def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
+ "structured.flatten_elementwise",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Flattens elementwise linalg ops by collapsing all of their loop dimensions into a single one.
+
+ Returns one handle:
+ - The flattened linalg operation.
+
+ #### Return modes:
+
+ Returns a definite failure if the target is not isolated from above.
+ Returns a silenceable failure if the flattening rewrite failed to apply.
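+
+ A minimal usage sketch (mirroring the regression tests added in this
+ PR; `%0` is assumed to be a handle to a matched elementwise linalg op):
+
+ ```mlir
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ ```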
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type($target, results)";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>
+ ];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::linalg::LinalgOp target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Transpose Conv2D
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..57fce5e7a749f0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3243,6 +3243,93 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
+ transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ rewriter.setInsertionPoint(target);
+ auto flatten = [&](linalg::LinalgOp op) -> FailureOr<linalg::GenericOp> {
+ if (!isElementwise(target)) {
+ return rewriter.notifyMatchFailure(
+ target, "only elementwise flattening is supported");
+ }
+ if (!llvm::all_of(target.getIndexingMapsArray(),
+ [](auto map) { return map.isMinorIdentity(); })) {
+ return rewriter.notifyMatchFailure(
+ target, "only minor identity indexing maps is supported");
+ }
+ ShapedType nonEmptyShapeType = nullptr;
+ for (const auto &resultVal : target.getDpsInitsMutable()) {
+ auto resultType = resultVal.get().getType();
+ if (ShapedType resultShapedType = dyn_cast<ShapedType>(resultType)) {
+ if (resultShapedType.getShape().empty())
+ continue;
+ if (nonEmptyShapeType == nullptr) {
+ nonEmptyShapeType = resultShapedType;
+ } else if (resultShapedType != nonEmptyShapeType) {
+ return rewriter.notifyMatchFailure(
+ target, "all operands (except rank 0) must have same types");
+ }
+ }
+ }
+ if (target.hasPureBufferSemantics()) {
+ if (!llvm::all_of(target->getOperands(), [](Value operand) {
+ if (auto memRefTy = dyn_cast<MemRefType>(operand.getType()))
+ return memRefTy.getLayout().isIdentity();
+ return true;
+ })) {
+ return rewriter.notifyMatchFailure(
+ target, "only memrefs with identity layout is supported");
+ }
+ }
+ ReassociationIndices reassociation(nonEmptyShapeType.getRank());
+ std::iota(reassociation.begin(), reassociation.end(), 0);
+ auto flattenOperand = [&](const Value &operand) {
+ return (!isa<MemRefType>(operand.getType()))
+ ? operand
+ : rewriter
+ .create<memref::CollapseShapeOp>(target.getLoc(),
+ operand, reassociation)
+ .getResult();
+ };
+ SmallVector<Value, 2> flattenedInputs(
+ llvm::map_range(target.getDpsInputs(), [&](const Value &operand) {
+ return flattenOperand(operand);
+ }));
+ SmallVector<Value, 2> flattenedInits(
+ llvm::map_range(target.getDpsInits(), [&](const Value &operand) {
+ return flattenOperand(operand);
+ }));
+
+ SmallVector<AffineMap, 4> flattenedMaps(llvm::map_range(
+ llvm::concat<Value>(flattenedInputs, flattenedInits),
+ [&](const Value &val) {
+ if (auto memRefTy = dyn_cast<MemRefType>(val.getType()))
+ return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(),
+ target.getContext());
+ return AffineMap::getMinorIdentityMap(1, 0, target.getContext());
+ }));
+
+ auto flattenedLinalgOp = rewriter.create<linalg::GenericOp>(
+ target.getLoc(), TypeRange(), flattenedInputs, flattenedInits,
+ flattenedMaps,
+ SmallVector<utils::IteratorType>{utils::IteratorType::parallel});
+ flattenedLinalgOp.getRegion().takeBody(target->getRegion(0));
+ return flattenedLinalgOp;
+ };
+ auto maybeFlattened = flatten(target);
+ if (failed(maybeFlattened))
+ return emitDefaultSilenceableFailure(target);
+ results.push_back(*maybeFlattened);
+ rewriter.eraseOp(target);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//
>From aff79baad62b53f8f10f733d5ff3c0068556549d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 11 Feb 2024 14:57:07 -0600
Subject: [PATCH 2/2] Add a couple regression tests
---
.../TransformOps/LinalgTransformOps.cpp | 50 +++++++-----
.../Dialect/Linalg/flatten-elementwise.mlir | 77 +++++++++++++++++++
2 files changed, 106 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/flatten-elementwise.mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 57fce5e7a749f0..15f7f82e24f3a5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3252,19 +3252,22 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
- auto flatten = [&](linalg::LinalgOp op) -> FailureOr<linalg::GenericOp> {
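+ // Nothing to flatten if the iteration space is already at most rank 1;
+ // forward the target unchanged.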
+ if (target.getNumLoops() <= 1) {
+ results.push_back(target);
+ return DiagnosedSilenceableFailure::success();
+ }
+ auto flatten = [&](linalg::LinalgOp &op) -> FailureOr<linalg::LinalgOp> {
if (!isElementwise(target)) {
return rewriter.notifyMatchFailure(
target, "only elementwise flattening is supported");
}
+ // TODO: Support broadcasting and permutations
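+ // A minor identity map's results are its trailing loop dimensions in
+ // order, e.g. (d0, d1) -> (d0, d1); transposed or broadcast indexing
+ // maps are rejected for now.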
if (!llvm::all_of(target.getIndexingMapsArray(),
[](auto map) { return map.isMinorIdentity(); })) {
return rewriter.notifyMatchFailure(
target, "only minor identity indexing maps is supported");
}
ShapedType nonEmptyShapeType = nullptr;
- for (const auto &resultVal : target.getDpsInitsMutable()) {
- auto resultType = resultVal.get().getType();
+ for (const auto &resultVal : target->getOperands()) {
+ auto resultType = resultVal.getType();
if (ShapedType resultShapedType = dyn_cast<ShapedType>(resultType)) {
if (resultShapedType.getShape().empty())
continue;
@@ -3277,6 +3280,7 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
}
}
if (target.hasPureBufferSemantics()) {
+ // TODO: Relax restrictions on layout
if (!llvm::all_of(target->getOperands(), [](Value operand) {
if (auto memRefTy = dyn_cast<MemRefType>(operand.getType()))
return memRefTy.getLayout().isIdentity();
@@ -3285,8 +3289,11 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
return rewriter.notifyMatchFailure(
target, "only memrefs with identity layout is supported");
}
+ } else {
+ // TODO: Support tensors
+ return rewriter.notifyMatchFailure(target, "tensors are not supported");
}
- ReassociationIndices reassociation(nonEmptyShapeType.getRank());
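+ // Collapse all loop dimensions into one reassociation group
+ // [0, 1, ..., numLoops - 1], e.g. [[0, 1]] for a rank-2 iteration space.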
+ ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto flattenOperand = [&](const Value &operand) {
return (!isa<MemRefType>(operand.getType()))
@@ -3296,37 +3303,38 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
operand, reassociation)
.getResult();
};
- SmallVector<Value, 2> flattenedInputs(
- llvm::map_range(target.getDpsInputs(), [&](const Value &operand) {
- return flattenOperand(operand);
- }));
- SmallVector<Value, 2> flattenedInits(
- llvm::map_range(target.getDpsInits(), [&](const Value &operand) {
+ SmallVector<Value, 2> flattenedOperands(
+ llvm::map_range(target->getOperands(), [&](const Value &operand) {
return flattenOperand(operand);
}));
- SmallVector<AffineMap, 4> flattenedMaps(llvm::map_range(
- llvm::concat<Value>(flattenedInputs, flattenedInits),
- [&](const Value &val) {
+ SmallVector<AffineMap, 4> flattenedMaps(
+ llvm::map_range(flattenedOperands, [&](const Value &val) {
if (auto memRefTy = dyn_cast<MemRefType>(val.getType()))
return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(),
target.getContext());
return AffineMap::getMinorIdentityMap(1, 0, target.getContext());
}));
- auto flattenedLinalgOp = rewriter.create<linalg::GenericOp>(
- target.getLoc(), TypeRange(), flattenedInputs, flattenedInits,
- flattenedMaps,
- SmallVector<utils::IteratorType>{utils::IteratorType::parallel});
- flattenedLinalgOp.getRegion().takeBody(target->getRegion(0));
- return flattenedLinalgOp;
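+ // Rewrite the op in place: swap in the collapsed operands and, for
+ // linalg.generic, update the indexing maps and iterator types to the
+ // rank-1 form.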
+ rewriter.modifyOpInPlace(op, [&]() {
+ op->setOperands(flattenedOperands);
+ // TODO: Find a more general way to determine if op requires explicit
+ // indexing_maps and iterator_types
+ if (isa<linalg::GenericOp>(op)) {
+ op->setAttr("indexing_maps",
+ rewriter.getAffineMapArrayAttr(flattenedMaps));
+ op->setAttr(
+ "iterator_types",
+ rewriter.getArrayAttr({IteratorTypeAttr::get(
+ rewriter.getContext(), utils::IteratorType::parallel)}));
+ }
+ });
+ return op;
};
auto maybeFlattened = flatten(target);
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back(*maybeFlattened);
- rewriter.eraseOp(target);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
new file mode 100644
index 00000000000000..e360fc3ff51784
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
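+// Each test flattens a rank-2 elementwise op on memref<32x7xf32> into a
+// rank-1 op on the collapsed memref<224xf32> (32 * 7 = 224).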
+// CHECK-LABEL: func.func @fill(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
+// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
+func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
+ linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @map(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+ linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
+// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
+// CHECK-NEXT: linalg.yield %[[SUM]]
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+ linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %0 = arith.addf %a, %b : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %flattened = transform.structured.flatten_elementwise %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
\ No newline at end of file