[Mlir-commits] [mlir] 91bbebc - [mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (#120465)

llvmlistbot at llvm.org
Fri Dec 27 08:52:38 PST 2024


Author: Kunwar Grover
Date: 2024-12-27T16:52:34Z
New Revision: 91bbebc7e118cceae1fc0e349de08094a3cd2fe7

URL: https://github.com/llvm/llvm-project/commit/91bbebc7e118cceae1fc0e349de08094a3cd2fe7
DIFF: https://github.com/llvm/llvm-project/commit/91bbebc7e118cceae1fc0e349de08094a3cd2fe7.diff

LOG: [mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (#120465)

This PR adds a new interface method, getPartialResultTilePosition, to
PartialReductionOpInterface, which allows callers to query the position
of the tile computed for the partial result. Previously, tiling the
reduction dimension with the PartialReductionOuterReduction strategy
produced wrong results when the result has transposed parallel
dimensions.

Other fixes that were needed to make this PR work:

- Instead of ad-hoc logic that placed the new reduction dimensions in
the partial result based on the iteration space, the reduction
dimensions are now always appended to the partial result tensor (see
the example below).
- Remove the usage of PartialReductionOpInterface in the Mesh dialect.
The implementation only needed a neutral element for the reduction, but
ended up going through PartialReductionOpInterface to get it, which is
not right; it was also passing the wrong sizes to the interface.
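
For example, the new test added in this PR tiles the reduction
dimension d1 of a linalg.generic whose result map transposes the
parallel dimensions (body elided):

  %red = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2) -> (d2, d0)>],
     iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%arg0 : tensor<?x?x?xf32>) outs(%out : tensor<?x?xf32>) { ... }

With tile_sizes = [0, 5, 0], the reduction dimension d1 is appended to
the result map, so the partial result uses
affine_map<(d0, d1, d2) -> (d2, d0, d1)> and accumulates into a
tensor<?x?x5xf32>.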

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index b75fc5e806afbe..50b69b8f8d833e 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
         /*defaultImplementation=*/[{
           return failure();
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of the partial result tile computed by
+          the tiled operation. This is the same as
+          TilingInterface::getResultTilePosition, but determines the result
+          tile position for partial reduction.
+        }],
+        /*retType=*/"::llvm::LogicalResult",
+        /*methodName=*/"getPartialResultTilePosition",
+        /*args=*/(ins
+            "::mlir::OpBuilder &":$b,
+            "unsigned":$resultNumber,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
+            "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
+            "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
+            "::mlir::ArrayRef<int>":$reductionDims),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
       >
   ];
 }
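
For the transposed example above, the offsets and sizes computed by
this method drive the slices into the partial accumulator. A hedged
reading of the Linalg implementation further down (SSA names here are
illustrative):

  // Slice of the tensor<?x?x5xf32> partial accumulator for one tile:
  // parallel dims use the tile offsets (0 here), the appended reduction
  // dim is always at offset 0, and the sizes are the tile sizes
  // permuted by the partial result map (d2, d0, d1).
  %ext = tensor.extract_slice %acc[0, 0, 0] [%d2, %d0, %k] [1, 1, 1]
      : tensor<?x?x5xf32> to tensor<?x?x?xf32>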

diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 5bf2f91c2c7bc8..92cfba2549a3f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
 static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
                       ArrayRef<MeshSharding> resultShardings,
                       SymbolTableCollection &symbolTable) {
-  for (const MeshSharding& sharding : operandShardings) {
+  for (const MeshSharding &sharding : operandShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
   }
 
-  for (const MeshSharding& sharding : resultShardings) {
+  for (const MeshSharding &sharding : resultShardings) {
     if (sharding) {
       return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
     }
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
 // the original operand.
 // The other processes would use the reduction operation neutral tensor.
 static Value createDestinationPassingStyleInitOperand(
-    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
-    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+    LinalgOp op, int operandNumber, Value spmdizedOperand,
+    ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+    ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
       meshOp.getSymName(), reductionMeshAxes, builder);
   Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
     builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
     SmallVector<OpFoldResult> shape =
         tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
-    PartialReductionOpInterface partialReductionIface =
-        llvm::cast<PartialReductionOpInterface>(op.getOperation());
-    assert(op->getNumResults() == 1 && "Multiple results not supported.");
-    FailureOr<SmallVector<Value>> reductionNeutralTensor =
-        partialReductionIface.generateInitialTensorForPartialReduction(
-            builder, builder.getLoc(), shape, {});
-    assert(succeeded(reductionNeutralTensor));
-    builder.create<scf::YieldOp>(reductionNeutralTensor.value());
+
+    SmallVector<Operation *> combinerOps;
+    matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
+    assert(combinerOps.size() == 1);
+    std::optional<TypedAttr> neutralEl =
+        arith::getNeutralElement(combinerOps[0]);
+
+    Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
+                                                 neutralEl.value().getType());
+    Value constant =
+        builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
+    Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
+                     .getResult(0);
+
+    builder.create<scf::YieldOp>(fill);
   }
   return ifOp.getResult(0);
 }
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
   Value spmdizedInitOperand =
       spmdizationMap.lookup(op->getOperands()[operandIdx]);
   newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
-      op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+      op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
   return newOperands;
 }
 
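Conceptually, the else-branch now materializes the neutral element
directly. A minimal sketch (SSA names are illustrative; 0.0 assumes an
arith.addf combiner):

  // Create a tensor of the operand's shape filled with the combiner's
  // neutral element and yield it from the else-branch of the scf.if.
  %empty = tensor.empty(%d0) : tensor<?xf32>
  %cst = arith.constant 0.000000e+00 : f32
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
  scf.yield %fill : tensor<?xf32>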

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f86715a94b268a..b7764da26a7f47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
 
-/// External model implementation of PartialReductionInterface for LinalgOps.
+/// Return an AffineMap for a partial result for the given result number,
+/// assuming the partial tiling strategy is outer-reduction loop +
+/// inner-parallel tile. The returned AffineMap can be used as the replacement
+/// AffineMap for the inner-parallel tile linalg op for the given result number.
+///
+/// The new AffineMap is the old AffineMap with reduction dimensions appended
+/// at the end.
+static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
+                                           ArrayRef<int> reductionDims,
+                                           unsigned resultNumber) {
+  AffineMap map =
+      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
+  for (int redPos : reductionDims) {
+    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
+                           map.getNumResults());
+  }
+  return map;
+}
+
+/// External model implementation of PartialReductionInterface for
+/// LinalgOps.
 template <typename LinalgOpTy>
 struct LinalgOpPartialReductionInterface
     : public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
     if (linalgOp.hasPureBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
 
+    // LinalgOp implements TilingInterface.
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    SmallVector<OpFoldResult> shape =
+        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
+                            [](Range x) { return x.size; });
+
+    SmallVector<OpFoldResult> tiledShape;
+    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
+      if (isZeroIndex(tileSize)) {
+        tiledShape.push_back(dimSize);
+      } else {
+        tiledShape.push_back(tileSize);
+      }
+    }
+
     SmallVector<Value> inits;
     for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
          ++initIdx) {
-      // Insert the new parallel dimension based on the index of the reduction
-      // loops. This could be controlled by user for more flexibility.
       SmallVector<Operation *, 4> combinerOps;
       if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
         return op->emitOpError(
             "Failed to get an identity value for the reduction operation.");
 
-      ArrayRef<int64_t> oldShape =
-          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
-
-      // Calculate the new shape, we insert the new dimensions based on the
-      // index of the reduction dimensions.
-      SmallVector<int64_t> newOutputShape;
-      SmallVector<Value> dynamicDims;
-      int64_t currReductionDims = 0;
-      DenseSet<int> reductionDimsSet(reductionDims.begin(),
-                                     reductionDims.end());
-      for (int64_t idx :
-           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
-        if (reductionDimsSet.contains(idx)) {
-          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
-          currReductionDims++;
-          continue;
-        }
-        int64_t oldIdx = idx - currReductionDims;
-        int64_t dim = oldShape[oldIdx];
-        newOutputShape.push_back(dim);
-        if (ShapedType::isDynamic(dim))
-          dynamicDims.push_back(b.create<tensor::DimOp>(
-              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      // Append the new partial result dimensions.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
+      SmallVector<OpFoldResult> partialResultShape;
+      for (AffineExpr dimExpr : partialMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        partialResultShape.push_back(tiledShape[dim.getPosition()]);
       }
-      Value emptyTensor = b.create<tensor::EmptyOp>(
-          loc, newOutputShape,
-          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+
+      Type elType =
+          getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
+      Value emptyTensor =
+          b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
       Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
       auto identityTensor =
           b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
       // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
       // this with a for range loop when we have it.
       AffineMap newMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      for (int redPos : reductionDims) {
-        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
-                                     newMap.getNumResults());
-      }
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
       newInitMaps.push_back(newMap);
     }
 
@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
                                          Location loc, ValueRange partialReduce,
                                          ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    SmallVector<int64_t> reductionDimsInt64(reductionDims);
-    auto reduction = b.create<linalg::ReduceOp>(
-        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
-        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
-          int64_t numInits = linalgOp.getNumDpsInits();
-          SmallVector<Value> yieldedValues;
-          for (int idx : llvm::seq<int>(0, numInits)) {
+
+    // Permute the reduction dims according to the partial result map.
+
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<Operation *> mergeOperations;
+    SmallVector<Value> replacements;
+    for (int idx : llvm::seq(numInits)) {
+      // linalg.reduce's iteration space is the tiled result's iteration space
+      // (and not the tiled operation's iteration space). To account for this,
+      // permute the reduction dimensions based on the partial result map of the
+      // tiled result.
+      AffineMap partialMap =
+          getPartialResultAffineMap(linalgOp, reductionDims, idx);
+      SmallVector<int64_t> partialReductionDims;
+      for (auto [resultNum, dimExpr] :
+           llvm::enumerate(partialMap.getResults())) {
+        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+        if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+          partialReductionDims.push_back(resultNum);
+        }
+      }
+
+      Value partialResult = partialReduce[idx];
+      Value init = linalgOp.getDpsInits()[idx];
+
+      auto reduction = b.create<linalg::ReduceOp>(
+          loc, partialResult, init, partialReductionDims,
+          [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
             // Get the combiner op.
             SmallVector<Operation *, 4> combinerOps;
             matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
             Operation *clonedReductionOp = b.clone(*combinerOps[0]);
             // Combine the input at idx and output at numInits + idx.
-            clonedReductionOp->setOperand(0, inputs[idx]);
-            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
-            // Yield.
-            yieldedValues.push_back(clonedReductionOp->getResult(0));
-          }
-          b.create<linalg::YieldOp>(loc, yieldedValues);
-        });
-    return MergeResult{
-        {reduction.getOperation()},
-        llvm::map_to_vector(reduction->getResults(),
-                            [](OpResult r) -> Value { return r; })};
+            clonedReductionOp->setOperand(0, inputs[0]);
+            clonedReductionOp->setOperand(1, inputs[1]);
+            b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+          });
+
+      mergeOperations.push_back(reduction);
+      replacements.push_back(reduction->getResult(0));
+    }
+
+    return MergeResult{mergeOperations, replacements};
+  }
+
+  LogicalResult getPartialResultTilePosition(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &resultOffsets,
+      SmallVector<OpFoldResult> &resultSizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    AffineMap partialMap =
+        getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
+    for (AffineExpr dimExpr : partialMap.getResults()) {
+      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+      resultSizes.push_back(sizes[dim]);
+
+      if (llvm::find(reductionDims, dim) != reductionDims.end()) {
+        // Reduction dims are reduced, and are always output in the same
+        // place. So use offset 0 for them.
+        resultOffsets.push_back(b.getIndexAttr(0));
+      } else {
+        resultOffsets.push_back(offsets[dim]);
+      }
+    }
+
+    return success();
   }
 };
 
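A worked instance of getPartialResultAffineMap for the transposed test
case below, where reductionDims = [1]:

  // Result map of the original op:
  affine_map<(d0, d1, d2) -> (d2, d0)>
  // d1 is appended, giving the partial result map:
  affine_map<(d0, d1, d2) -> (d2, d0, d1)>
  // getPartialResultTilePosition then reports, per partial-result dim:
  //   resultSizes   = [sizes[2], sizes[0], sizes[1]]
  //   resultOffsets = [offsets[2], offsets[0], 0]  // reduction dim -> 0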

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf8411b..b548f8ce8b560b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                                     resultOffset, resultSize);
   case scf::SCFTilingOptions::ReductionTilingStrategy::
       PartialReductionOuterReduction: {
-    // TODO: This does not work for non identity accesses to the result tile.
-    // The proper fix is to add a getPartialResultTilePosition method to
-    // PartialReductionOpInterface.
-    resultOffset =
-        SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
-    for (size_t i = 0; i < offsets.size(); i++) {
-      resultSize.push_back(
-          tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+    if (!redOp) {
+      return rewriter.notifyMatchFailure(
+          op, "PartialReductionOuterReduction tiling strategy is only supported "
+              "for operations implementing PartialReductionOpInterface");
     }
-    return success();
+    // Get reduction dimensions.
+    // TODO: PartialReductionOpInterface should really query TilingInterface
+    // itself and find reduction dimensions.
+    SmallVector<int> reductionDims;
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(op.getLoopIteratorTypes())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+                                              resultOffset, resultSize,
+                                              reductionDims);
+  }
   default:
     return rewriter.notifyMatchFailure(op,
                                        "unhandled reduction tiling strategy");
   }
-  }
 }
 
 static FailureOr<MergeResult>

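This path is exercised by transform.structured.tile_reduction_using_for;
for the new test below, iterator_types = ["parallel", "reduction",
"parallel"] yields reductionDims = [1]:

  %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
    by tile_sizes = [0, 5, 0]
    : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
                              !transform.any_op, !transform.any_op)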
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index cce4b4efa61c8b..9d34c80822d0e1 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -32,8 +32,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
 //     CHECK:     %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
@@ -81,13 +80,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 //     CHECK: func @reduction_tile_transpose
-//     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
-//     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
+//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x5xf32>
+//     CHECK:   linalg.fill {{.*}} : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   scf.for
-//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
 //     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
-//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
-//     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+//     CHECK:     scf.yield {{.*}} : tensor<?x5xf32>
 //     CHECK:   }
 //     CHECK:   linalg.reduce
 //     CHECK:   return
@@ -129,8 +128,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -183,9 +181,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
 // CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]], %[[D2]]) : tensor<?x?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
 // CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
@@ -243,8 +239,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 // CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
-//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
 //     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
@@ -422,8 +417,8 @@ func.func @reduction_tile_multiple_results(%arg0: tensor<?x?xf32>, %out: tensor<
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %12, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
-      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    %1, %12, %2, %3, %4, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
       transform.yield
   }
 }
@@ -444,4 +439,44 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       scf.yield %[[INSERT1]], %[[INSERT1]]
 // CHECK:       linalg.reduce
 // CHECK:         arith.addf
+// CHECK:       linalg.reduce
 // CHECK:         arith.maximumf
+
+// -----
+
+func.func @reduction_tile_multi_dim_transpose(%arg0: tensor<?x?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d2, d0)>],
+   iterator_types = ["parallel", "reduction", "parallel"]}
+   ins(%arg0 : tensor<?x?x?xf32>)
+   outs(%out : tensor<?x?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %42 = arith.addf %arg7, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<?x?xf32>
+  return %red : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
+      by tile_sizes = [0, 5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+      transform.yield
+  }
+}
+
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+//     CHECK: func @reduction_tile_multi_dim_transpose
+//     CHECK:   tensor.empty(%{{.*}}) : tensor<?x?x5xf32>
+//     CHECK:   linalg.fill {{.*}} : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+//     CHECK:   scf.for
+//     CHECK:     %[[K:.*]] = affine.min
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0, 0] [%[[D2:.*]], %[[D0:.*]], %[[K]]] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?x?xf32>
+//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?x?xf32>) outs(%[[EXT]] : tensor<?x?x?xf32>)
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0, 0] [%[[D2]], %[[D0]], %[[K]]] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x5xf32>
+//     CHECK:     scf.yield {{.*}} : tensor<?x?x5xf32>
+//     CHECK:   }
+//     CHECK:   linalg.reduce
+//     CHECK:   return

