[Mlir-commits] [mlir] [mlir][linalg] Fix crash in tile_reduction when output map has constant exprs (PR #189166)

Mehdi Amini llvmlistbot at llvm.org
Fri Apr 3 03:51:26 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189166

>From 073dfa0ae2e3de26a6ff681b789001c5e6dae8e4 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 2 Apr 2026 03:46:03 -0700
Subject: [PATCH] [mlir][linalg] Fix crash in tile_reduction when output map
 has constant exprs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The `getInitSliceInfo*` helpers, `generateInitialTensorForPartialReduction`,
and the code that collects the dimensions to reduce when merging partial
results unconditionally cast every AffineMap result to `AffineDimExpr`. When
the output indexing map contains a constant expression (e.g.
`affine_map<(d0,d1,d2)->(d0,0,d2)>`), the cast triggers an assertion.

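A constant index means the op always accesses a fixed position in that
dimension — it does not imply the dimension has size 1. The fix detects
`AffineConstantExpr` at the four affected call sites. When building the
partial-result tensor and its slices, it uses the actual output operand
dimension size (offset 0, full size). When merging partial results, it skips
that position, since a constant position is never a reduction dimension.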

Add regression tests for `tile_reduction_using_forall` (constant-indexed dim
of size 1 and > 1, dynamic shapes, two consecutive constants) and for
`tile_reduction_using_for` (constant-indexed dim of size > 1, dynamic shapes).

Fixes #173025

Assisted-by: Claude Code
---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  80 +++--
 .../Linalg/transform-tile-reduction.mlir      | 301 ++++++++++++++++++
 2 files changed, 363 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 558ebdebd65c5..7ed07e1ec9a01 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -434,13 +434,22 @@ struct InitSliceInfo {
 static InitSliceInfo getInitSliceInfoForOuterReduction(
     MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
-    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
+    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
+    ArrayRef<OpFoldResult> initOperandShape) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
   Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
   Attribute one = IntegerAttr::get(IndexType::get(context), 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+  for (auto [resultIdx, dimExpr] :
+       llvm::enumerate(partialReductionMap.getResults())) {
+    if (isa<AffineConstantExpr>(dimExpr)) {
+      // A constant index in the output map accesses a fixed position; keep
+      // the full output dimension to match the original output operand shape.
+      initOffsets.push_back(zero);
+      initSizes.push_back(initOperandShape[resultIdx]);
+      continue;
+    }
     unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
     if (reductionDims.contains(dim)) {
       initOffsets.push_back(zero);
@@ -460,13 +469,24 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
 static InitSliceInfo getInitSliceInfoForOuterParallel(
     MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
-    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
+    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
+    ArrayRef<OpFoldResult> initOperandShape) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
+  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
   Attribute one = IntegerAttr::get(IndexType::get(context), 1);
   SmallVector<OpFoldResult> initStrides(initRank, one);
   SmallVector<OpFoldResult> resultShape;
-  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+  for (auto [resultIdx, dimExpr] :
+       llvm::enumerate(partialReductionMap.getResults())) {
+    if (isa<AffineConstantExpr>(dimExpr)) {
+      // A constant index accesses a fixed position; keep the full output
+      // dimension to match the original output operand shape.
+      initOffsets.push_back(zero);
+      initSizes.push_back(initOperandShape[resultIdx]);
+      resultShape.push_back(initOperandShape[resultIdx]);
+      continue;
+    }
     unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
     if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
       initOffsets.push_back(splitReductionIvs[dimPos.value()]);
@@ -490,17 +510,18 @@ static InitSliceInfo getInitSliceInfo(MLIRContext *context,
                                       ArrayRef<OpFoldResult> sizes,
                                       const SetVector<unsigned> &reductionDims,
                                       ArrayRef<OpFoldResult> splitReductionIvs,
-                                      AffineMap partialReductionMap) {
+                                      AffineMap partialReductionMap,
+                                      ArrayRef<OpFoldResult> initOperandShape) {
   if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
-    return getInitSliceInfoForOuterReduction(context, offsets, sizes,
-                                             reductionDims, splitReductionIvs,
-                                             partialReductionMap);
+    return getInitSliceInfoForOuterReduction(
+        context, offsets, sizes, reductionDims, splitReductionIvs,
+        partialReductionMap, initOperandShape);
   }
   assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
          "unexpected ReductionTilingStrategy");
-  return getInitSliceInfoForOuterParallel(context, offsets, sizes,
-                                          reductionDims, splitReductionIvs,
-                                          partialReductionMap);
+  return getInitSliceInfoForOuterParallel(
+      context, offsets, sizes, reductionDims, splitReductionIvs,
+      partialReductionMap, initOperandShape);
 }
 
 /// External model implementation of PartialReductionInterface for
@@ -538,7 +559,17 @@ struct LinalgOpPartialReductionInterface
 
       // Append the new partial result dimensions.
       SmallVector<OpFoldResult> partialResultShape;
-      for (AffineExpr dimExpr : partialMap.getResults()) {
+      Value initValue = linalgOp.getDpsInits()[initIdx];
+      SmallVector<OpFoldResult> initShape =
+          tensor::getMixedSizes(b, loc, initValue);
+      for (auto [resultIdx, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        if (isa<AffineConstantExpr>(dimExpr)) {
+          // A constant index in the output map accesses a fixed position; use
+          // the actual output dimension size (not a hardcoded 1).
+          partialResultShape.push_back(initShape[resultIdx]);
+          continue;
+        }
         auto dim = cast<AffineDimExpr>(dimExpr);
         partialResultShape.push_back(sizes[dim.getPosition()]);
       }
@@ -591,11 +622,15 @@ struct LinalgOpPartialReductionInterface
 
     // Step 2b: Extract a slice of the init operands.
     SmallVector<Value, 1> tiledInits;
-    for (auto [partialReductionMap, valueToTile] :
-         llvm::zip_equal(partialReductionMaps, init)) {
+    for (auto [partialReductionMap, valueToTile, initOperandValue] :
+         llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) {
+      // Compute the actual shape of the original init operand for handling
+      // constant expressions in the partial reduction map.
+      SmallVector<OpFoldResult> initOperandShape =
+          tensor::getMixedSizes(b, loc, initOperandValue);
       InitSliceInfo sliceInfo = getInitSliceInfo(
           b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
-          splitReductionIvs, partialReductionMap);
+          splitReductionIvs, partialReductionMap, initOperandShape);
       auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
       RankedTensorType sliceResultType = RankedTensorType::get(
           sliceInfo.resultShape, valueToTileType.getElementType(),
@@ -670,6 +705,8 @@ struct LinalgOpPartialReductionInterface
       SmallVector<int64_t> partialReductionDims;
       for (auto [resultNum, dimExpr] :
            llvm::enumerate(partialMap.getResults())) {
+        if (isa<AffineConstantExpr>(dimExpr))
+          continue; // Constant dims are never reduction dims.
         unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
         if (llvm::is_contained(reductionDims, dim)) {
           partialReductionDims.push_back(resultNum);
@@ -707,9 +744,16 @@ struct LinalgOpPartialReductionInterface
     auto linalgOp = cast<LinalgOp>(op);
     SmallVector<AffineMap> partialReductionMaps =
         getPartialResultAffineMaps(linalgOp, reductionDims);
-    InitSliceInfo sliceInfo = getInitSliceInfo(
-        b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
-        splitReductionIvs, partialReductionMaps[resultNumber]);
+    // Compute the actual shape of the init operand for handling constant
+    // expressions in the partial reduction map.
+    Value initOperandValue = linalgOp.getDpsInits()[resultNumber];
+    Location loc = op->getLoc();
+    SmallVector<OpFoldResult> initOperandShape =
+        tensor::getMixedSizes(b, loc, initOperandValue);
+    InitSliceInfo sliceInfo =
+        getInitSliceInfo(b.getContext(), tilingStrategy, offsets, sizes,
+                         reductionDims, splitReductionIvs,
+                         partialReductionMaps[resultNumber], initOperandShape);
     std::swap(resultOffsets, sliceInfo.offsets);
     std::swap(resultSizes, sliceInfo.sizes);
 
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index e31d4f333557c..25af6796d1f63 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -728,3 +728,304 @@ module attributes {transform.with_named_sequence} {
 //       CHECK:     linalg.generic
 //       CHECK:       %[[LOCAL_IDX:.+]] = linalg.index 1 : index
 //       CHECK:       affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]
+
+// -----
+
+// Verify that tile_reduction_using_forall handles output indexing maps that
+// contain constant expressions (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`)
+// without crashing. Previously, generateInitialTensorForPartialReduction
+// unconditionally cast every map result to AffineDimExpr, triggering an
+// assertion when a constant expression was present (issue #173025).
+
+func.func @reduction_tile_with_constant_in_output_map(
+    %arg0: tensor<1x4096x64xf32>,
+    %arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+    ],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<1x1x64xf32>
+  return %0 : tensor<1x1x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.tile_reduction_using_forall %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @reduction_tile_with_constant_in_output_map
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty() : tensor<1x1x64x1024xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32>
+//       CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x1x64x1024xf32>) {
+//       CHECK:     %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
+//       CHECK:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
+//       CHECK:     %[[PARTIAL:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[IN_SLICE]] :
+//  CHECK-SAME:         outs(%[[INIT_SLICE]] :
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   linalg.reduce ins(%[[L]] : tensor<1x1x64x1024xf32>) outs(%arg1 : tensor<1x1x64xf32>) dimensions = [3]
+//       CHECK:   return %{{.*}} : tensor<1x1x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_forall with a constant in the output map when
+// the constant-indexed dimension size is greater than 1 (K=3). The partial
+// init tensor must use the actual output dim size (3), not a hardcoded 1.
+
+func.func @reduction_tile_forall_constant_dim_k_gt_1(
+    %arg0: tensor<1x4096x64xf32>,
+    %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+    ],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<1x3x64xf32>
+  return %0 : tensor<1x3x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.tile_reduction_using_forall %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @reduction_tile_forall_constant_dim_k_gt_1
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty() : tensor<1x3x64x1024xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x1024xf32>) -> tensor<1x3x64x1024xf32>
+//       CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x1024xf32>) {
+//       CHECK:     %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
+//       CHECK:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
+//       CHECK:     %[[PARTIAL:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[IN_SLICE]] :
+//  CHECK-SAME:         outs(%[[INIT_SLICE]] :
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   linalg.reduce ins(%[[L]] : tensor<1x3x64x1024xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
+//       CHECK:   return %{{.*}} : tensor<1x3x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_for with a constant in the output map when
+// the constant-indexed dimension size is greater than 1 (K=3). The partial
+// init tensor must use the actual output dim size (3), not a hardcoded 1.
+
+func.func @reduction_tile_for_constant_dim_k_gt_1(
+    %arg0: tensor<1x4096x64xf32>,
+    %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+    ],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<1x3x64xf32>
+  return %0 : tensor<1x3x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @reduction_tile_for_constant_dim_k_gt_1
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty() : tensor<1x3x64x4xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x4xf32>) -> tensor<1x3x64x4xf32>
+//       CHECK:   %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x4xf32>) {
+//       CHECK:     %[[PARTIAL:.+]] = linalg.generic
+//  CHECK-SAME:         outs(%[[ARG3]] :
+//       CHECK:     scf.yield %[[PARTIAL]]
+//       CHECK:   }
+//       CHECK:   linalg.reduce ins(%[[L]] : tensor<1x3x64x4xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
+//       CHECK:   return %{{.*}} : tensor<1x3x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_forall handles dynamic output shapes combined
+// with a constant expression in the output map. The partial init tensor must
+// keep the full output size (3) at the constant-indexed position rather than
+// a hardcoded 1, while the remaining dynamic dimensions stay dynamic.
+
+// CHECK-LABEL: func @reduction_tile_dynamic_constant_map
+func.func @reduction_tile_dynamic_constant_map(
+    %arg0: tensor<?x4096x?xf32>,
+    %arg1: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+    ],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%arg0 : tensor<?x4096x?xf32>) outs(%arg1 : tensor<?x3x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x3x?xf32>
+  return %0 : tensor<?x3x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.tile_reduction_using_forall %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// Verify that the partial init tensor uses the correct dynamic shape. The
+// constant-indexed dim 1 is static (size 3) so the partial tensor is
+// tensor<?x3x?x1024xf32>. The extract_slice within the forall body uses
+// size 3 at the constant-indexed position, not a hardcoded 1.
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty({{.*}}) : tensor<?x3x?x1024xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<?x3x?x1024xf32>)
+//       CHECK:   scf.forall
+//       CHECK:     tensor.extract_slice {{.*}} [{{.*}}, 3, {{.*}}, 1]
+//       CHECK:   linalg.reduce ins({{.*}} : tensor<?x3x?x1024xf32>) {{.*}} dimensions = [3]
+//       CHECK:   return %{{.*}}
+
+// -----
+
+// Verify tile_reduction_using_forall handles two consecutive constant
+// expressions in the same output map (e.g. `(d0,d1,d2)->(0,0,d2)`).
+// Both constant-indexed dimensions must use the actual output dim size.
+
+// CHECK-LABEL: func @reduction_tile_two_constants_in_map
+func.func @reduction_tile_two_constants_in_map(
+    %arg0: tensor<1x4096x64xf32>,
+    %arg1: tensor<3x5x64xf32>) -> tensor<3x5x64xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (0, 0, d2)>
+    ],
+    iterator_types = ["reduction", "reduction", "parallel"]
+  } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<3x5x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<3x5x64xf32>
+  return %0 : tensor<3x5x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1:4 = transform.structured.tile_reduction_using_forall %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty() : tensor<3x5x64x1024xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<3x5x64x1024xf32>) -> tensor<3x5x64x1024xf32>
+//       CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<3x5x64x1024xf32>) {
+//       CHECK:     tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
+//       CHECK:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1]
+//       CHECK:     linalg.generic
+//       CHECK:     scf.forall.in_parallel {
+//       CHECK:       tensor.parallel_insert_slice {{.*}} into %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1]
+//       CHECK:     }
+//       CHECK:   }
+//       CHECK:   linalg.reduce ins(%[[L]] : tensor<3x5x64x1024xf32>) outs(%arg1 : tensor<3x5x64xf32>) dimensions = [3]
+//       CHECK:   return %{{.*}} : tensor<3x5x64xf32>
+
+// -----
+
+// Verify tile_reduction_using_for handles dynamic output shapes combined with
+// a constant expression in the output map. The partial init tensor must keep
+// the full output size (3) at the constant-indexed position, not a hardcoded 1.
+
+// CHECK-LABEL: func @reduction_tile_for_dynamic_constant_map
+func.func @reduction_tile_for_dynamic_constant_map(
+    %arg0: tensor<?x4096x?xf32>,
+    %arg1: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+    ],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%arg0 : tensor<?x4096x?xf32>) outs(%arg1 : tensor<?x3x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = arith.addf %in, %out : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x3x?xf32>
+  return %0 : tensor<?x3x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg
+        : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+        by tile_sizes = [0, 4, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
+                                  !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// Verify the partial init tensor uses tensor.dim for dynamic dims and the
+// static constant-indexed position uses size 3. The extract_slice inside
+// the for loop body must use size 3 at the constant-indexed position.
+//   CHECK-DAG:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[E:.*]] = tensor.empty({{.*}}) : tensor<?x3x?x4xf32>
+//       CHECK:   %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<?x3x?x4xf32>)
+//       CHECK:   %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x3x?x4xf32>) {
+//       CHECK:     tensor.extract_slice %arg0[{{.*}}] [{{.*}}, 4, {{.*}}] [1, 1, 1]
+//       CHECK:     tensor.extract_slice %[[ARG3]][{{.*}}] [{{.*}}, 3, {{.*}}, 4] [1, 1, 1, 1]
+//       CHECK:   }
+//       CHECK:   linalg.reduce ins(%[[L]] : tensor<?x3x?x4xf32>) outs(%arg1 : tensor<?x3x?xf32>) dimensions = [3]
+//       CHECK:   return %{{.*}}



More information about the Mlir-commits mailing list