[Mlir-commits] [mlir] [mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (PR #120465)

Kunwar Grover llvmlistbot at llvm.org
Fri Dec 27 08:38:24 PST 2024


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/120465

From ab58988015d599ddf9390b6ab49a3d4827c3755b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 18 Dec 2024 15:45:52 +0000
Subject: [PATCH 1/3] Fix invalid use of PartialReductionOpInterface in
 MeshShardingInterfaceImpl

---
 .../Transforms/MeshShardingInterfaceImpl.cpp  | 34 ++++++++++++-------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 5bf2f91c2c7bc8..92cfba2549a3f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding& sharding : operandShardings) {
+  for (const MeshSharding &sharding : operandShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
-  for (const MeshSharding& sharding : resultShardings) {
+  for (const MeshSharding &sharding : resultShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
-    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+    LinalgOp op, int operandNumber, Value spmdizedOperand,
+    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
       meshOp.getSymName(), reductionMeshAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
-    PartialReductionOpInterface partialReductionIface =
-        llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    assert(op->getNumResults() == 1 && "Multiple results not supported.");
-    FailureOr<SmallVector<Value>> reductionNeutralTensor =
-        partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensor));
-    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+    SmallVector<Operation *> combinerOps;
+    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+    assert(combinerOps.size() == 1);
+    std::optional<TypedAttr> neutralEl =
+        arith::getNeutralElement(combinerOps[0]);
+
+    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+                                                 neutralEl.value().getType());
+    Value constant =
+        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+                     .getResult(0);
+
+    builder.create<scf::YieldOp>(fill);
   }
   return ifOp.getResult(0);
 }
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
   Value spmdizedInitOperand =
       spmdizationMap.lookup(op->getOperands()[operandIdx]);
   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-      op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
   return newOperands;
 }
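
For context, the surrounding control flow of this helper (reconstructed as a
sketch, not part of the hunk above): only the lead process of the reduction
group keeps the original operand; every other process yields the
reduction-neutral fill built in the else-branch.

  Value isLeadProcess = builder.create<arith::CmpIOp>(
      arith::CmpIPredicate::eq, processLinearIndexInReductionGroup, zero);
  SmallVector<Type> resultTypes{spmdizedOperand.getType()};
  auto ifOp = builder.create<scf::IfOp>(resultTypes, isLeadProcess,
                                        /*withElseRegion=*/true);
  {
    // Then-branch: this process owns the original init operand.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
    builder.create<scf::YieldOp>(spmdizedOperand);
  }
  // Else-branch: the neutral fill from the hunk above is yielded.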
 

From e5a56f5fd975a54cefe1c950b2e547a53a279759 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 18 Dec 2024 18:22:01 +0000
Subject: [PATCH 2/3] [mlir][scf] Add getPartialResultTilePosition to
 PartialReductionOpInterface

---
 .../mlir/Interfaces/TilingInterface.td        |  22 +++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 157 ++++++++++++------
 .../SCF/Transforms/TileUsingInterface.cpp     |  28 ++--
 .../Linalg/transform-tile-reduction.mlir      |  67 ++++++--
 4 files changed, 196 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index b75fc5e806afbe..50b69b8f8d833e 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the partial result tile computed by
+          the tiled operation. This is same as
+          TilingInterface:::getResultTilePosition, but determines the result
+          tile position for partial reduction.
+        }],
+        /*retType=*/"::llvm::LogicalResult",
+        /*methodName=*/"getPartialResultTilePosition",
+        /*args=*/(ins
+            "::mlir::OpBuilder &":$b,
+            "unsigned":$resultNumber,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+            "::mlir::ArrayRef<int>":$reductionDims),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }
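
A sketch of the expected call pattern for the new hook (this mirrors the
TileUsingInterface.cpp change later in this patch; `b`, `op`, `index`,
`offsets` and `sizes` are assumed to be in scope in the tiling driver):

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp)
    return failure();
  // Recover the reduction dimensions from the iterator types.
  SmallVector<int> reductionDims;
  for (auto [idx, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
    if (iteratorType == utils::IteratorType::reduction)
      reductionDims.push_back(idx);
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(redOp.getPartialResultTilePosition(b, index, offsets, sizes,
                                                resultOffsets, resultSizes,
                                                reductionDims)))
    return failure();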
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f86715a94b268a..098016cd0fd226 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,7 +324,20 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// External model implementation of PartialReductionInterface for LinalgOps.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionOpInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +351,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -355,33 +381,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");
 
-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
       auto identityTensor =
           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +419,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
 
@@ -476,29 +484,74 @@ struct LinalgOpPartialReductionInterface
                                          Location loc, ValueRange partialReduce,
                                          ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims according to the partial result map.
+
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq(numInits)) {
+      // linalg.reduce's iteration space is the result's iteration space (and
+      // not the operation's iteration space). To account for this, permute the
+      // reduction dimensions based on the partial result map.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
             matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
-            // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            // Combine the partial result (inputs[0]) with init (inputs[1]).
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and their results are always written
+        // to the same place, so use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };
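
A worked instance of getPartialResultAffineMap (values taken from the
transposed multi-dim reduction test added below; `ctx` is an assumed
MLIRContext *): the init operand's indexing map is (d0, d1, d2) -> (d2, d0)
and reductionDims = {1}, so appending the reduction dimension yields
(d0, d1, d2) -> (d2, d0, d1), which is #[[MAP2]] checked in the test.

  AffineExpr d0 = getAffineDimExpr(0, ctx);
  AffineExpr d1 = getAffineDimExpr(1, ctx);
  AffineExpr d2 = getAffineDimExpr(2, ctx);
  AffineMap initMap =
      AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d2, d0}, ctx);
  AffineMap partialMap = initMap.insertResult(d1, initMap.getNumResults());
  // partialMap is now (d0, d1, d2) -> (d2, d0, d1).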
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf8411b..b548f8ce8b560b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                     resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
-    // TODO: This does not work for non identity accesses to the result tile.
-    // The proper fix is to add a getPartialResultTilePosition method to
-    // PartialReductionOpInterface.
-    resultOffset =
-        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
-    for (size_t i = 0; i < offsets.size(); i++) {
-      resultSize.push_back(
-          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only supported"
+              "for operations implementing PartialReductionOpInterface");
     }
-    return success();
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                              resultOffset, resultSize,
+                                              reductionDims);
+  }
   default:
     return rewriter.notifyMatchFailure(op,
                                        "unhandled reduction tiling strategy");
   }
-  }
 }
 
 static FailureOr<MergeResult>
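
Once getPartialResultTilePosition has filled resultOffset/resultSize, the
driver stitches the tiled partial result back into the enlarged init with
unit strides. A hedged sketch (the actual insertion happens elsewhere in
this file; `partialInitArg` is an illustrative name for the destination
iter_arg):

  SmallVector<OpFoldResult> strides(resultOffset.size(),
                                    rewriter.getIndexAttr(1));
  Value updated = rewriter.create<tensor::InsertSliceOp>(
      op.getLoc(), tiledResult, partialInitArg, resultOffset, resultSize,
      strides);

This is the tensor.insert_slice checked by the updated tests below.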
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cce4b4efa61c8b..9d34c80822d0e1 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
 //     CHECK:     %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 //     CHECK: func @reduction_tile_transpose
-//     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
-//     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x5xf32>
+//     CHECK:   linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   scf.for
-//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
 //     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
-//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
-//     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+//     CHECK:     scf.yield {{.*}} : tensor<?x5xf32>
 //     CHECK:   }
 //     CHECK:   linalg.reduce
 //     CHECK:   return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 //     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
-      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
       transform.yield
   }
 }
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       scf.yield %[[INSERT1]], %[[INSERT1]]
 // CHECK:       linalg.reduce
 // CHECK:         arith.addf
+// CHECK:       linalg.reduce
 // CHECK:         arith.maximumf
+
+// -----
+
+func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d2, d0)>],
+   iterator_types = ["parallel", "reduction", "parallel"]}
+   ins(%arg0 : tensor<?x?x?xf32>)
+   outs(%out : tensor<?x?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %42 = arith.addf %arg7, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<?x?xf32>
+  return %red : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+//     CHECK: func @reduction_tile_multi_dim_transpose
+//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
+//     CHECK:   linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+//     CHECK:   scf.for
+//     CHECK:     %[[K:.*]] = affine.min
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
+//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
+//     CHECK:     scf.yield {{.*}} : tensor<?x?x5xf32>
+//     CHECK:   }
+//     CHECK:   linalg.reduce
+//     CHECK:   return

From 93d07b9d3c89fe5f264bdab06611f83aa51089ac Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 27 Dec 2024 16:38:03 +0000
Subject: [PATCH 3/3] Address comments

---
 .../Linalg/Transforms/TilingInterfaceImpl.cpp      | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 098016cd0fd226..b7764da26a7f47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,6 +324,13 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
+/// Return an AffineMap for a partial result of the given result number,
+/// assuming the partial tiling strategy is an outer reduction loop with an
+/// inner parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel-tile linalg op of that result.
+///
+/// The new AffineMap is the old AffineMap with the reduction dimensions
+/// appended at the end.
 static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
                                            ArrayRef<int> reductionDims,
                                            unsigned resultNumber) {
@@ -491,9 +498,10 @@ struct LinalgOpPartialReductionInterface
     SmallVector<Operation *> mergeOperations;
     SmallVector<Value> replacements;
     for (int idx : llvm::seq(numInits)) {
-      // linalg.reduce's iteration space is the result's iteration space (and
-      // not the operation's iteration space). To account for this, permute the
-      // reduction dimensions based on the partial result map.
+      // linalg.reduce's iteration space is the tiled result's iteration space
+      // (and not the tiled operation's iteration space). To account for this,
+      // permute the reduction dimensions based on the partial result map of the
+      // tiled result.
       AffineMap partialMap =
           getPartialResultAffineMap(linalgOp, reductionDims, idx);
       SmallVector<int64_t> partialReductionDims;
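
A worked instance of that permutation (same transposed example as the new
test in this patch): with partialMap = (d0, d1, d2) -> (d2, d0, d1) and
reductionDims = {1}, the reduction dim d1 appears as result position 2 of
the partial tensor, so linalg.reduce must reduce dimension 2, not 1.

  SmallVector<int64_t> partialReductionDims;
  for (auto [resultNum, dimExpr] : llvm::enumerate(partialMap.getResults()))
    if (llvm::is_contained(reductionDims,
                           cast<AffineDimExpr>(dimExpr).getPosition()))
      partialReductionDims.push_back(resultNum);
  // Here partialReductionDims == {2}.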


