[Mlir-commits] [mlir] [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations (PR #68522)

Aviad Cohen llvmlistbot at llvm.org
Sun Oct 8 04:38:32 PDT 2023


https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/68522

[mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations

From 5e59fd1725a54a3ab5ab6f8d1cb46f3bb06fb8ce Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Fri, 6 Oct 2023 15:07:37 +0300
Subject: [PATCH] [mlir][linalg] Enable CollapseLinalgDimensions to collapse
 memref based operations

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++++++---
 mlir/test/Dialect/Linalg/collapse-dim.mlir    | 46 +++++++++++++++++++
 2 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..6f4b0ff60ca97c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
     return operand;
 
   // Insert a reshape to collapse the dimensions.
-  auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
-      loc, operand, operandReassociation);
-  return reshapeOp.getResult();
+  if (isa<MemRefType>(operand.getType())) {
+    return builder
+        .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  } else {
+    return builder
+        .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+        .getResult();
+  }
 }
 
 /// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
       }))
     return failure();
 
+  bool hasBufferSemantics = genericOp.hasBufferSemantics();
+  if (hasBufferSemantics &&
+      !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+        if (!memRefToCollapse)
+          return true;
+
+        return memref::CollapseShapeOp::isGuaranteedCollapsible(
+            memRefToCollapse, foldedIterationDims);
+      }))
+    return rewriter.notifyMatchFailure(genericOp,
+                                       "memref is not guaranteed collapsible");
+
   CollapsingInfo collapsingInfo;
   if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
                                        foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
     Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
                                             collapsingInfo, rewriter);
     outputOperands.push_back(newOutput);
-    resultTypes.push_back(newOutput.getType());
+    // If the op has "buffer semantics", then the init operands are ranked
+    // memrefs and the op has no results.
+    if (!hasBufferSemantics)
+      resultTypes.push_back(newOutput.getType());
   }
 
   // Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
           genericOp.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
-      Value result = rewriter.create<tensor::ExpandShapeOp>(
-          loc, originalResultType, collapsedOpResult, reassociation);
-      results.push_back(result);
+      if (isa<MemRefType>(collapsedOpResult.getType())) {
+        Value result = rewriter.create<memref::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      } else {
+        Value result = rewriter.create<tensor::ExpandShapeOp>(
+            loc, originalResultType, collapsedOpResult, reassociation);
+        results.push_back(result);
+      }
     } else {
       results.push_back(collapsedOpResult);
     }
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..106154ba3a553bd 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
 // CHECK-LABEL: func @uncollapsable(
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL:   func.func private @collapsable_memref(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK:             linalg.yield %[[VAL_9]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK:         }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
+  ^bb0(%in: i32, %in_0: i32, %out: i32):
+    %0 = arith.addi %in, %in_0 : i32
+    linalg.yield %0 : i32
+  }
+  return %alloc : memref<2x6x24x48xi32>
+}



More information about the Mlir-commits mailing list