[Mlir-commits] [mlir] d4ae7ee - [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations (#68522)
llvmlistbot at llvm.org
Tue Oct 10 13:36:17 PDT 2023
Author: Aviad Cohen
Date: 2023-10-10T23:36:11+03:00
New Revision: d4ae7ee662d2f318c0e4105c674e0634733b48eb
URL: https://github.com/llvm/llvm-project/commit/d4ae7ee662d2f318c0e4105c674e0634733b48eb
DIFF: https://github.com/llvm/llvm-project/commit/d4ae7ee662d2f318c0e4105c674e0634733b48eb.diff
LOG: [mlir][linalg] Enable CollapseLinalgDimensions to collapse memref based operations (#68522)
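
(A minimal sketch of the transform's effect, mirroring the @collapsable_memref test added below; the value names %lhs, %rhs, and %out are illustrative.) Before this change, getCollapsedOpOperand unconditionally created tensor.collapse_shape, so memref operands could not be collapsed. With the change, collapsing dimensions [2, 3] of a memref-based elementwise op reshapes each operand with memref.collapse_shape and rewrites the generic op over the collapsed type:

  // 3-D identity map after folding dims [2, 3] into one.
  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  // Each memref operand is collapsed with memref.collapse_shape
  // (where a tensor operand would get tensor.collapse_shape).
  %lhs = memref.collapse_shape %arg0 [[0], [1], [2, 3]]
      : memref<1x24x32x8xf32> into memref<1x24x256xf32>
  %rhs = memref.collapse_shape %arg1 [[0], [1], [2, 3]]
      : memref<1x24x32x8xf32> into memref<1x24x256xf32>
  %out = memref.collapse_shape %alloc [[0], [1], [2, 3]]
      : memref<1x24x32x8xf32> into memref<1x24x256xf32>
  linalg.generic {indexing_maps = [#map, #map, #map],
                  iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%lhs, %rhs : memref<1x24x256xf32>, memref<1x24x256xf32>)
      outs(%out : memref<1x24x256xf32>) {
  ^bb0(%in: f32, %in_0: f32, %o: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }

Because the op has buffer semantics it produces no SSA results, so no expand_shape is needed on the results; the original buffer (%alloc above) is used directly.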
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/collapse-dim.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 069c613cc246d6a..6f4b0ff60ca97c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1388,9 +1388,15 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
return operand;
// Insert a reshape to collapse the dimensions.
- auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
- loc, operand, operandReassociation);
- return reshapeOp.getResult();
+ if (isa<MemRefType>(operand.getType())) {
+ return builder
+ .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ } else {
+ return builder
+ .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
+ .getResult();
+ }
}
/// Modify the `linalg.index` operations in the original generic op, to its
@@ -1444,6 +1450,19 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
}))
return failure();
+ bool hasBufferSemantics = genericOp.hasBufferSemantics();
+ if (hasBufferSemantics &&
+ !llvm::all_of(genericOp->getOperands(), [&](Value operand) -> bool {
+ MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+ if (!memRefToCollapse)
+ return true;
+
+ return memref::CollapseShapeOp::isGuaranteedCollapsible(
+ memRefToCollapse, foldedIterationDims);
+ }))
+ return rewriter.notifyMatchFailure(genericOp,
+ "memref is not guaranteed collapsible");
+
CollapsingInfo collapsingInfo;
if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
foldedIterationDims))) {
@@ -1499,7 +1518,10 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
Value newOutput = getCollapsedOpOperand(loc, genericOp, &output,
collapsingInfo, rewriter);
outputOperands.push_back(newOutput);
- resultTypes.push_back(newOutput.getType());
+ // If the op has "buffer semantics", then the init operands are ranked
+ // memrefs and the op has no results.
+ if (!hasBufferSemantics)
+ resultTypes.push_back(newOutput.getType());
}
// Create the generic op.
@@ -1538,9 +1560,15 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
genericOp.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
- Value result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
- results.push_back(result);
+ if (isa<MemRefType>(collapsedOpResult.getType())) {
+ Value result = rewriter.create<memref::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ } else {
+ Value result = rewriter.create<tensor::ExpandShapeOp>(
+ loc, originalResultType, collapsedOpResult, reassociation);
+ results.push_back(result);
+ }
} else {
results.push_back(collapsedOpResult);
}
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 6737a6e15da5afe..106154ba3a553bd 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -70,3 +70,49 @@ func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41
// CHECK-LABEL: func @uncollapsable(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+// -----
+
+// CHECK-LABEL: func.func private @collapsable_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
+// CHECK: ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
+// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
+// CHECK: linalg.yield %[[VAL_9]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_2]] : memref<1x24x32x8xf32>
+// CHECK: }
+
+func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return %alloc : memref<1x24x32x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncollapsable_strided_memref(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
+ %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
+ linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %0 = arith.addi %in, %in_0 : i32
+ linalg.yield %0 : i32
+ }
+ return %alloc : memref<2x6x24x48xi32>
+}
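
A note on the strided case above (a sketch of the reasoning, assuming the test collapses dims [2, 3] as in @collapsable_memref): memref::CollapseShapeOp::isGuaranteedCollapsible rejects the subview type memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1]>>, because collapsing dims 2 and 3 would require stride(d2) == size(d3) * stride(d3) = 24 * 1 = 24, while the actual stride(d2) is 48. The new guard therefore fails the match with "memref is not guaranteed collapsible", and the generic op keeps all four parallel iterators, as the CHECK lines verify.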