[Mlir-commits] [mlir] 6033544 - [mlir][linalg] Fix memref type verification in CollapseLinalgDimensions (#147245)

llvmlistbot at llvm.org
Wed Jul 9 01:04:13 PDT 2025


Author: zbenzion
Date: 2025-07-09T01:04:08-07:00
New Revision: 603354417375394eb5a9bab2780b6e585da3f628

URL: https://github.com/llvm/llvm-project/commit/603354417375394eb5a9bab2780b6e585da3f628
DIFF: https://github.com/llvm/llvm-project/commit/603354417375394eb5a9bab2780b6e585da3f628.diff

LOG: [mlir][linalg] Fix memref type verification in CollapseLinalgDimensions (#147245)

When collapsing linalg dimensions, we check whether the op's memref
operands are guaranteed to be collapsible. However, the check currently
assumes that the matching indexing map is the identity map.

This commit modifies this behavior and checks whether each memref is
collapsible on its transformed (operand-space) dimensions instead.
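
For illustration, in the new test added below the indexing map
(d0, d1, d2, d3) -> (d0, d2, d3, d1) is a non-identity projected
permutation: the folded iteration dimensions land at operand positions 1
and 2, giving the operand reassociation [[0], [1, 2], [3]], whereas the
old check applied the folded iteration dimensions directly to the memref
type, which is only correct for identity indexing maps. A minimal sketch
of the per-operand check (simplified from the diff below; collapsingInfo
is assumed to have been initialized from the folded iteration dimensions,
as in the surrounding function):

  bool allCollapsible =
      llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
        auto memRefToCollapse = dyn_cast<MemRefType>(opOperand.get().getType());
        if (!memRefToCollapse)
          return true; // non-memref operands need no layout check

        // Translate the iteration-space grouping into this operand's own
        // dimension grouping via its indexing map, then verify that the
        // memref layout supports collapsing those dimensions.
        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
        SmallVector<ReassociationIndices> operandReassociation =
            getOperandReassociation(indexingMap, collapsingInfo);
        return memref::CollapseShapeOp::isGuaranteedCollapsible(
            memRefToCollapse, operandReassociation);
      });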

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/collapse-dim.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..9c0f6e5d6469e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1717,26 +1717,30 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
       }))
     return failure();
 
+  CollapsingInfo collapsingInfo;
+  if (failed(
+          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
+    return rewriter.notifyMatchFailure(
+        op, "illegal to collapse specified dimensions");
+  }
+
   bool hasPureBufferSemantics = op.hasPureBufferSemantics();
   if (hasPureBufferSemantics &&
-      !llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
-        MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
+      !llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
+        MemRefType memRefToCollapse =
+            dyn_cast<MemRefType>(opOperand.get().getType());
         if (!memRefToCollapse)
           return true;
 
+        AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
+        SmallVector<ReassociationIndices> operandReassociation =
+            getOperandReassociation(indexingMap, collapsingInfo);
         return memref::CollapseShapeOp::isGuaranteedCollapsible(
-            memRefToCollapse, foldedIterationDims);
+            memRefToCollapse, operandReassociation);
       }))
     return rewriter.notifyMatchFailure(op,
                                        "memref is not guaranteed collapsible");
 
-  CollapsingInfo collapsingInfo;
-  if (failed(
-          collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
-    return rewriter.notifyMatchFailure(
-        op, "illegal to collapse specified dimensions");
-  }
-
   // Bail on non-canonical ranges.
   SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {

diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 61bedecbdca5a..c8b03f8dd5151 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -100,6 +100,35 @@ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memre
 
 // -----
 
+// CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+// CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+
+// CHECK-LABEL:   func.func @collapsable_memref_projected_ops(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<1x24x32x8xf32>, %[[ARG1:.*]]: memref<1x24x32x8xf32>, %[[ARG2:.*]]: memref<1x24x32x8xf32, #[[$ATTR_0]]>) {
+// CHECK:           %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK:           %[[VAL_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32, #[[$ATTR_0]]> into memref<1x768x8xf32, strided<[7680, 10, 1]>>
+// CHECK:           linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x768x8xf32>, memref<1x768x8xf32>) outs(%[[VAL_2]] : memref<1x768x8xf32, strided<[7680, 10, 1]>>) {
+// CHECK:           ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
+// CHECK:             %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
+// CHECK:             linalg.yield %[[VAL_6]] : f32
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
+func.func @collapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @uncollapsable_strided_memref(
 //       CHECK:   linalg.generic
 //  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -119,6 +148,23 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
 
 // -----
 
+// CHECK-LABEL: func @uncollapsable_memref_projected_ops(
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 8 + d3)>
+func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
+  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %0 = arith.addf %in, %in_0 : f32
+    linalg.yield %0 : f32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @linalg_copy(
 // CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
 // CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {

More information about the Mlir-commits mailing list