[Mlir-commits] [mlir] [mlir][linalg] Fix memref type verification in CollapseLinalgDimensions (PR #147245)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 7 00:39:53 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (zbenzion)
<details>
<summary>Changes</summary>
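When collapsing the iteration dimensions of a linalg op, we check whether the op's memref operands are guaranteed to be collapsible. However, the check currently applies the folded iteration dimensions directly to each memref, which implicitly assumes that the operand's matching indexing map is the identity map.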
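This commit fixes the check by mapping the folded iteration dimensions through each memref operand's indexing map (via `getOperandReassociation`). It then checks whether the memref is collapsible along the resulting operand dimensions. For example, with the indexing map `(d0, d1, d2, d3) -> (d0, d2, d3, d1)` used in the new tests, folding iteration dimensions 2 and 3 corresponds to collapsing memref dimensions 1 and 2. Because the check now depends on `CollapsingInfo`, its initialization is moved ahead of the check.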
---
Full diff: https://github.com/llvm/llvm-project/pull/147245.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+14-10)
- (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+46)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..9c0f6e5d6469e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1717,26 +1717,30 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
}))
return failure();
+ CollapsingInfo collapsingInfo;
+ if (failed(
+ collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
+ return rewriter.notifyMatchFailure(
+ op, "illegal to collapse specified dimensions");
+ }
+
bool hasPureBufferSemantics = op.hasPureBufferSemantics();
if (hasPureBufferSemantics &&
- !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
- MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+ !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
+ MemRefType memRefToCollapse =
+ dyn_cast<MemRefType>(opOperand.get().getType());
if (!memRefToCollapse)
return true;
+ AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
+ SmallVector<ReassociationIndices> operandReassociation =
+ getOperandReassociation(indexingMap, collapsingInfo);
return memref::CollapseShapeOp::isGuaranteedCollapsible(
- memRefToCollapse, foldedIterationDims);
+ memRefToCollapse, operandReassociation);
}))
return rewriter.notifyMatchFailure(op,
"memref is not guaranteed collapsible");
- CollapsingInfo collapsingInfo;
- if (failed(
- collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
- return rewriter.notifyMatchFailure(
- op, "illegal to collapse specified dimensions");
- }
-
// Bail on non-canonical ranges.
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 61bedecbdca5a..c8b03f8dd5151 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -100,6 +100,35 @@ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memre
// -----
+// CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+// CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+
+// CHECK-LABEL: func.func @collapsable_memref_projected_ops(
+// CHECK-SAME: %[[ARG0:.*]]: memref<1x24x32x8xf32>, %[[ARG1:.*]]: memref<1x24x32x8xf32>, %[[ARG2:.*]]: memref<1x24x32x8xf32, #[[$ATTR_0]]>) {
+// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK: %[[VAL_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32, #[[$ATTR_0]]> into memref<1x768x8xf32, strided<[7680, 10, 1]>>
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x768x8xf32>, memref<1x768x8xf32>) outs(%[[VAL_2]] : memref<1x768x8xf32, strided<[7680, 10, 1]>>) {
+// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK: linalg.yield %[[VAL_6]] : f32
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+func.func @collapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+ linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @uncollapsable_strided_memref(
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -119,6 +148,23 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
// -----
+// CHECK-LABEL: func @uncollapsable_memref_projected_ops(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 8 + d3)>
+func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+ linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %0 = arith.addf %in, %in_0 : f32
+ linalg.yield %0 : f32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func.func @linalg_copy(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/147245
More information about the Mlir-commits
mailing list