[Mlir-commits] [mlir] [mlir][transform] Implement `FlattenElementwiseLinalgOp` transform op (PR #81431)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 21:19:57 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/81431

>From 6e05d6a3ed218797ae264fc88f8998a0a4b945dc Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 11 Feb 2024 02:33:16 -0600
Subject: [PATCH 1/4] Implement FlattenElementwiseLinalgOp transform

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 42 +++++++++
 .../TransformOps/LinalgTransformOps.cpp       | 87 +++++++++++++++++++
 2 files changed, 129 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872f..d8d864d14ea698 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2295,6 +2295,48 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp
+//===----------------------------------------------------------------------===//
+
+def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
+    "structured.flatten_elementwise",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Flattens the iteration space and operands of elementwise linalg ops to 1-D.
+
+    Returns one handle:
+    - Flattened linalg operation.
+
+    #### Return modes:
+
+    Returns a silenceable failure if the target cannot be flattened, e.g. when
+    it is not elementwise, has tensor semantics or non-identity memref layouts.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose Conv2D
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 585fd14b40d764..57fce5e7a749f0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3243,6 +3243,93 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp.
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto flatten = [&](linalg::LinalgOp op) -> FailureOr<linalg::GenericOp> {
+    if (!isElementwise(target)) {
+      return rewriter.notifyMatchFailure(
+          target, "only elementwise flattening is supported");
+    }
+    if (!llvm::all_of(target.getIndexingMapsArray(),
+                      [](auto map) { return map.isMinorIdentity(); })) {
+      return rewriter.notifyMatchFailure(
+          target, "only minor identity indexing maps is supported");
+    }
+    ShapedType nonEmptyShapeType = nullptr;
+    for (const auto &resultVal : target.getDpsInitsMutable()) {
+      auto resultType = resultVal.get().getType();
+      if (ShapedType resultShapedType = dyn_cast<ShapedType>(resultType)) {
+        if (resultShapedType.getShape().empty())
+          continue;
+        if (nonEmptyShapeType == nullptr) {
+          nonEmptyShapeType = resultShapedType;
+        } else if (resultShapedType != nonEmptyShapeType) {
+          return rewriter.notifyMatchFailure(
+              target, "all operands (except rank 0) must have same types");
+        }
+      }
+    }
+    if (target.hasPureBufferSemantics()) {
+      if (!llvm::all_of(target->getOperands(), [](Value operand) {
+            if (auto memRefTy = dyn_cast<MemRefType>(operand.getType()))
+              return memRefTy.getLayout().isIdentity();
+            return true;
+          })) {
+        return rewriter.notifyMatchFailure(
+            target, "only memrefs with identity layout is supported");
+      }
+    }
+    ReassociationIndices reassociation(nonEmptyShapeType.getRank());
+    std::iota(reassociation.begin(), reassociation.end(), 0);
+    auto flattenOperand = [&](const Value &operand) {
+      return (!isa<MemRefType>(operand.getType()))
+                 ? operand
+                 : rewriter
+                       .create<memref::CollapseShapeOp>(target.getLoc(),
+                                                        operand, reassociation)
+                       .getResult();
+    };
+    SmallVector<Value, 2> flattenedInputs(
+        llvm::map_range(target.getDpsInputs(), [&](const Value &operand) {
+          return flattenOperand(operand);
+        }));
+    SmallVector<Value, 2> flattenedInits(
+        llvm::map_range(target.getDpsInits(), [&](const Value &operand) {
+          return flattenOperand(operand);
+        }));
+
+    SmallVector<AffineMap, 4> flattenedMaps(llvm::map_range(
+        llvm::concat<Value>(flattenedInputs, flattenedInits),
+        [&](const Value &val) {
+          if (auto memRefTy = dyn_cast<MemRefType>(val.getType()))
+            return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(),
+                                                  target.getContext());
+          return AffineMap::getMinorIdentityMap(1, 0, target.getContext());
+        }));
+
+    auto flattenedLinalgOp = rewriter.create<linalg::GenericOp>(
+        target.getLoc(), TypeRange(), flattenedInputs, flattenedInits,
+        flattenedMaps,
+        SmallVector<utils::IteratorType>{utils::IteratorType::parallel});
+    flattenedLinalgOp.getRegion().takeBody(target->getRegion(0));
+    return flattenedLinalgOp;
+    return success();
+  };
+  auto maybeFlattened = flatten(target);
+  if (failed(maybeFlattened))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(*maybeFlattened);
+  rewriter.eraseOp(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeConv2DOp
 //===----------------------------------------------------------------------===//

>From aff79baad62b53f8f10f733d5ff3c0068556549d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sun, 11 Feb 2024 14:57:07 -0600
Subject: [PATCH 2/4] Add a couple regression tests

---
 .../TransformOps/LinalgTransformOps.cpp       | 50 +++++++-----
 .../Dialect/Linalg/flatten-elementwise.mlir   | 77 +++++++++++++++++++
 2 files changed, 106 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/flatten-elementwise.mlir

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 57fce5e7a749f0..15f7f82e24f3a5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3252,19 +3252,22 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  auto flatten = [&](linalg::LinalgOp op) -> FailureOr<linalg::GenericOp> {
+  if (target.getNumLoops() <= 1)
+    return DiagnosedSilenceableFailure::success();
+  auto flatten = [&](linalg::LinalgOp &op) -> FailureOr<linalg::LinalgOp> {
     if (!isElementwise(target)) {
       return rewriter.notifyMatchFailure(
           target, "only elementwise flattening is supported");
     }
+    // TODO: Support broadcasting and permutations
     if (!llvm::all_of(target.getIndexingMapsArray(),
                       [](auto map) { return map.isMinorIdentity(); })) {
       return rewriter.notifyMatchFailure(
           target, "only minor identity indexing maps is supported");
     }
     ShapedType nonEmptyShapeType = nullptr;
-    for (const auto &resultVal : target.getDpsInitsMutable()) {
-      auto resultType = resultVal.get().getType();
+    for (const auto &resultVal : target->getOperands()) {
+      auto resultType = resultVal.getType();
       if (ShapedType resultShapedType = dyn_cast<ShapedType>(resultType)) {
         if (resultShapedType.getShape().empty())
           continue;
@@ -3277,6 +3280,7 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
       }
     }
     if (target.hasPureBufferSemantics()) {
+      // TODO: Relax restrictions on layout
       if (!llvm::all_of(target->getOperands(), [](Value operand) {
             if (auto memRefTy = dyn_cast<MemRefType>(operand.getType()))
               return memRefTy.getLayout().isIdentity();
@@ -3285,8 +3289,11 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
         return rewriter.notifyMatchFailure(
             target, "only memrefs with identity layout is supported");
       }
+    } else {
+      // TODO: Support tensors
+      return rewriter.notifyMatchFailure(target, "tensors are not supported");
     }
-    ReassociationIndices reassociation(nonEmptyShapeType.getRank());
+    ReassociationIndices reassociation(target.getNumLoops());
     std::iota(reassociation.begin(), reassociation.end(), 0);
     auto flattenOperand = [&](const Value &operand) {
       return (!isa<MemRefType>(operand.getType()))
@@ -3296,37 +3303,38 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
                                                         operand, reassociation)
                        .getResult();
     };
-    SmallVector<Value, 2> flattenedInputs(
-        llvm::map_range(target.getDpsInputs(), [&](const Value &operand) {
-          return flattenOperand(operand);
-        }));
-    SmallVector<Value, 2> flattenedInits(
-        llvm::map_range(target.getDpsInits(), [&](const Value &operand) {
+    SmallVector<Value, 2> flattenedOperands(
+        llvm::map_range(target->getOperands(), [&](const Value &operand) {
           return flattenOperand(operand);
         }));
 
-    SmallVector<AffineMap, 4> flattenedMaps(llvm::map_range(
-        llvm::concat<Value>(flattenedInputs, flattenedInits),
-        [&](const Value &val) {
+    SmallVector<AffineMap, 4> flattenedMaps(
+        llvm::map_range(flattenedOperands, [&](const Value &val) {
           if (auto memRefTy = dyn_cast<MemRefType>(val.getType()))
             return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(),
                                                   target.getContext());
           return AffineMap::getMinorIdentityMap(1, 0, target.getContext());
         }));
 
-    auto flattenedLinalgOp = rewriter.create<linalg::GenericOp>(
-        target.getLoc(), TypeRange(), flattenedInputs, flattenedInits,
-        flattenedMaps,
-        SmallVector<utils::IteratorType>{utils::IteratorType::parallel});
-    flattenedLinalgOp.getRegion().takeBody(target->getRegion(0));
-    return flattenedLinalgOp;
-    return success();
+    rewriter.modifyOpInPlace(op, [&]() {
+      op->setOperands(flattenedOperands);
+      // TODO: Find a more general way to determine if op requires explicit
+      // indexing_maps and iterator_types
+      if (isa<linalg::GenericOp>(op)) {
+        op->setAttr("indexing_maps",
+                    rewriter.getAffineMapArrayAttr(flattenedMaps));
+        op->setAttr(
+            "iterator_types",
+            rewriter.getArrayAttr({IteratorTypeAttr::get(
+                rewriter.getContext(), utils::IteratorType::parallel)}));
+      }
+    });
+    return op;
   };
   auto maybeFlattened = flatten(target);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
   results.push_back(*maybeFlattened);
-  rewriter.eraseOp(target);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
new file mode 100644
index 00000000000000..e360fc3ff51784
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @fill(
+// CHECK-SAME:                  %[[ARG0:.*]]: f32,
+// CHECK-SAME:                  %[[ARG1:.*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
+// CHECK-NEXT:    return
+func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
+    linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Ops with a single loop are already flat and must be forwarded unchanged.
+// CHECK-LABEL: func.func @fill_1d(
+// CHECK-SAME:                     %[[ARG0:.*]]: f32,
+// CHECK-SAME:                     %[[ARG1:.*]]: memref<224xf32>
+// CHECK-NEXT:    linalg.fill ins(%[[ARG0]] : f32) outs(%[[ARG1]] : memref<224xf32>)
+// CHECK-NEXT:    return
+func.func @fill_1d(%cst: f32, %arg: memref<224xf32>) {
+    linalg.fill ins(%cst: f32) outs(%arg: memref<224xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @map(
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+// CHECK-NEXT:    return
+func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+    linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @generic
+// CHECK-SAME:                 %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:                 %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+// CHECK-NEXT:       ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
+// CHECK-NEXT:         %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
+// CHECK-NEXT:         linalg.yield %[[SUM]]
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+    linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
+        ^bb0(%a: f32, %b: f32, %c: f32):
+            %0 = arith.addf %a, %b : f32
+            linalg.yield %0 : f32
+    }
+    return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

>From cd0ebb1051264dbffd4c0fb1a386150a05ff6ef2 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 13 Feb 2024 22:27:00 -0600
Subject: [PATCH 3/4] Refactor `collapseOpIterationDims` to work for all linalg
 ops

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 60 ++++++++-----------
 2 files changed, 27 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50e..a566745185ad99 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1081,9 +1081,8 @@ bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
 /// When valid, the method also collapses the operands of the op. Returns
 /// replacement values of the results of the original `linalgOp` by inserting
 /// reshapes to get back values of compatible types.
-template <typename LinalgType>
 FailureOr<SmallVector<Value>>
-collapseOpIterationDims(LinalgType op,
+collapseOpIterationDims(LinalgOp op,
                         ArrayRef<ReassociationIndices> foldedIterationDims,
                         RewriterBase &rewriter);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 286b07669a47f5..ce58caa6c39aad 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1449,12 +1449,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   }
 }
 
-template <typename LinalgType>
-Operation *createCollapsedOp(LinalgType op,
-                             const CollapsingInfo &collapsingInfo,
-                             RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to create");
+LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
+                           RewriterBase &rewriter) {
   Location loc = op->getLoc();
 
   // Get the input operands.
@@ -1479,14 +1475,17 @@ Operation *createCollapsedOp(LinalgType op,
       resultTypes.push_back(newOutput.getType());
   }
 
-  if (isa<linalg::CopyOp>(op)) {
-    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
-                                           outputOperands[0]);
-  }
+  Operation *collapsedOp = clone(
+      rewriter, op, resultTypes,
+      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));
 
   // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes =
-      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+  SmallVector<Attribute> iteratorTypes = llvm::map_to_vector(
+      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo),
+      [&](utils::IteratorType itTy) {
+        return cast<Attribute>(
+            IteratorTypeAttr::get(rewriter.getContext(), itTy));
+      });
 
   // Get the indexing maps.
   auto indexingMaps =
@@ -1494,25 +1493,22 @@ Operation *createCollapsedOp(LinalgType op,
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       });
 
-  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &op->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
+  // TODO: Find a more general way to determine if op requires explicit
+  // indexing_maps and iterator_types
+  if (isa<linalg::GenericOp>(op)) {
+    collapsedOp->setAttr("indexing_maps",
+                         rewriter.getAffineMapArrayAttr(indexingMaps));
+    collapsedOp->setAttr("iterator_types",
+                         rewriter.getArrayAttr(iteratorTypes));
+  }
 
-  return collapsedOp;
+  return cast<LinalgOp>(collapsedOp);
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-template <typename LinalgType>
 FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
-    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
+    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to collapse");
-
   // Bail on trivial no-op cases.
   if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1541,8 +1537,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
   }
 
   // Bail on non-canonical ranges.
-  SmallVector<Range> loopRanges =
-      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
+  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1558,8 +1553,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
         op, "expected all loop ranges to have zero start and unit stride");
   }
 
-  LinalgType collapsedOp = cast<LinalgType>(
-      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
+  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
   if (collapsedOp.hasIndexSemantics()) {
@@ -1632,9 +1626,8 @@ class FoldWithProducerReshapeOpByCollapsing
         continue;
       }
 
-      std::optional<SmallVector<Value>> replacements =
-          collapseOpIterationDims<linalg::GenericOp>(
-              genericOp, collapsableIterationDims, rewriter);
+      std::optional<SmallVector<Value>> replacements = collapseOpIterationDims(
+          genericOp, collapsableIterationDims, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
@@ -1675,8 +1668,7 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
     }
 
     std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
-                                            rewriter);
+        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
     if (!replacements) {
       return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }

>From 9d0b35a3a4d84179ab88357e9208ea9ab4b149ea Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 13 Feb 2024 23:19:12 -0600
Subject: [PATCH 4/4] Refactor `FlattenElementwiseLinalgOp` to use
 `collapseOpIterationDims`

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 +++-
 .../TransformOps/LinalgTransformOps.cpp       | 40 ++-----------------
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 16 ++++----
 3 files changed, 17 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a566745185ad99..65cf19e7a4fcd6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1074,6 +1074,15 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndices> dimSequences);
 
+/// Result of `collapseOpIterationDims`.
+struct CollapseResult {
+  /// Replacement values for the results of the original op, reshaped back to
+  /// types compatible with the original results.
+  SmallVector<Value> results;
+  /// The newly created op with collapsed iteration dimensions.
+  LinalgOp collapsedOp;
+};
+
-/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
+/// Collapses iteration dimensions of a LinalgOp. A precondition
 /// to calling this method is that for each list in `foldedIterationDim`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
@@ -1081,7 +1090,8 @@ bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
 /// When valid, the method also collapses the operands of the op. Returns
-/// replacement values of the results of the original `linalgOp` by inserting
-/// reshapes to get back values of compatible types.
+/// the collapsed op together with replacement values for the results of the
+/// original `linalgOp`, obtained by inserting reshapes to get back values of
+/// compatible types. Callers are responsible for replacing the original op.
-FailureOr<SmallVector<Value>>
+FailureOr<CollapseResult>
 collapseOpIterationDims(LinalgOp op,
                         ArrayRef<ReassociationIndices> foldedIterationDims,
                         RewriterBase &rewriter);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 15f7f82e24f3a5..25e72ab273833e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3254,7 +3254,11 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
   rewriter.setInsertionPoint(target);
-  if (target.getNumLoops() <= 1)
-    return DiagnosedSilenceableFailure::success();
-  auto flatten = [&](linalg::LinalgOp &op) -> FailureOr<linalg::LinalgOp> {
+  // Ops with at most one loop are already flat: forward them unchanged so the
+  // `transformed` result is still populated for every payload op.
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+  auto flatten = [&](linalg::LinalgOp &op) -> FailureOr<CollapseResult> {
     if (!isElementwise(target)) {
       return rewriter.notifyMatchFailure(
           target, "only elementwise flattening is supported");
@@ -3295,46 +3299,15 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
     }
     ReassociationIndices reassociation(target.getNumLoops());
     std::iota(reassociation.begin(), reassociation.end(), 0);
-    auto flattenOperand = [&](const Value &operand) {
-      return (!isa<MemRefType>(operand.getType()))
-                 ? operand
-                 : rewriter
-                       .create<memref::CollapseShapeOp>(target.getLoc(),
-                                                        operand, reassociation)
-                       .getResult();
-    };
-    SmallVector<Value, 2> flattenedOperands(
-        llvm::map_range(target->getOperands(), [&](const Value &operand) {
-          return flattenOperand(operand);
-        }));
-
-    SmallVector<AffineMap, 4> flattenedMaps(
-        llvm::map_range(flattenedOperands, [&](const Value &val) {
-          if (auto memRefTy = dyn_cast<MemRefType>(val.getType()))
-            return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(),
-                                                  target.getContext());
-          return AffineMap::getMinorIdentityMap(1, 0, target.getContext());
-        }));
-
-    rewriter.modifyOpInPlace(op, [&]() {
-      op->setOperands(flattenedOperands);
-      // TODO: Find a more general way to determine if op requires explicit
-      // indexing_maps and iterator_types
-      if (isa<linalg::GenericOp>(op)) {
-        op->setAttr("indexing_maps",
-                    rewriter.getAffineMapArrayAttr(flattenedMaps));
-        op->setAttr(
-            "iterator_types",
-            rewriter.getArrayAttr({IteratorTypeAttr::get(
-                rewriter.getContext(), utils::IteratorType::parallel)}));
-      }
-    });
-    return op;
+    return collapseOpIterationDims(op, reassociation, rewriter);
   };
   auto maybeFlattened = flatten(target);
   if (failed(maybeFlattened))
     return emitDefaultSilenceableFailure(target);
-  results.push_back(*maybeFlattened);
+  results.push_back(maybeFlattened->collapsedOp);
+  // `collapseOpIterationDims` builds a new op; replace (and thereby erase) the
+  // original so it does not remain in the IR alongside the flattened op.
+  rewriter.replaceOp(target, maybeFlattened->results);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ce58caa6c39aad..b81b67474565ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1506,7 +1506,7 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
+FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
     LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
   // Bail on trivial no-op cases.
@@ -1594,7 +1594,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
       results.push_back(collapsedOpResult);
     }
   }
-  return results;
+  return CollapseResult{std::move(results), collapsedOp};
 }
 
 namespace {
@@ -1626,14 +1626,14 @@ class FoldWithProducerReshapeOpByCollapsing
         continue;
       }
 
-      std::optional<SmallVector<Value>> replacements = collapseOpIterationDims(
+      FailureOr<CollapseResult> collapseResult = collapseOpIterationDims(
           genericOp, collapsableIterationDims, rewriter);
-      if (!replacements) {
+      if (failed(collapseResult)) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");
       }
 
-      rewriter.replaceOp(genericOp, *replacements);
+      rewriter.replaceOp(genericOp, collapseResult->results);
       return success();
     }
     return failure();
@@ -1667,12 +1667,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
           op, "specified dimensions cannot be collapsed");
     }
 
-    std::optional<SmallVector<Value>> replacements =
+    FailureOr<CollapseResult> collapseResult =
         collapseOpIterationDims(op, collapsableIterationDims, rewriter);
-    if (!replacements) {
+    if (failed(collapseResult)) {
       return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
     }
-    rewriter.replaceOp(op, *replacements);
+    rewriter.replaceOp(op, collapseResult->results);
     return success();
   }
 



More information about the Mlir-commits mailing list