[Mlir-commits] [mlir] 18b92c6 - [mlir][Linalg] Add a TileToForeachThread transform.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 19 04:59:19 PDT 2022


Author: Nicolas Vasilache
Date: 2022-07-19T04:56:11-07:00
New Revision: 18b92c66fe59a44f50bc211a418eaf48fe1cf7c1

URL: https://github.com/llvm/llvm-project/commit/18b92c66fe59a44f50bc211a418eaf48fe1cf7c1
DIFF: https://github.com/llvm/llvm-project/commit/18b92c66fe59a44f50bc211a418eaf48fe1cf7c1.diff

LOG: [mlir][Linalg] Add a TileToForeachThread transform.

This revision adds a new transformation that tiles a TilingInterface `op`
into an `scf.foreach_thread`, applying tiling by `num_threads`.
If non-empty, the `threadDimMapping` is added as an attribute to the
resulting `scf.foreach_thread`.
A tile size of 0 (i.e. tiling by the full size of the data) encodes that a
dimension is not tiled.
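
As an illustration, a minimal sketch of how the new transform is driven and
the rough shape of the loop it produces, based on the test added in this
revision. The SSA names, the elided `[...]` subscripts, and the per-thread
arithmetic shown in comments are schematic only, not a verbatim dump:

    %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])

    // Per tiled dimension, each thread id conceptually computes:
    //   maxSizePerThread = ceildiv(size, numThreads)
    //   offset(tid)      = lb + tid * maxSizePerThread
    //   size(tid)        = max(0, min(size - offset(tid), maxSizePerThread))
    // which yields, schematically, for the matmul in the test:
    scf.foreach_thread (%i, %j) in (%c10, %c20) -> (tensor<?x?xf32>) {
      %tA = tensor.extract_slice %A[...] : tensor<?x?xf32> to tensor<?x?xf32>
      %tB = tensor.extract_slice %B[...] : tensor<?x?xf32> to tensor<?x?xf32>
      %tC = tensor.extract_slice %C[...] : tensor<?x?xf32> to tensor<?x?xf32>
      %res = linalg.matmul ins(%tA, %tB : tensor<?x?xf32>, tensor<?x?xf32>)
                          outs(%tC : tensor<?x?xf32>) -> tensor<?x?xf32>
      scf.foreach_thread.perform_concurrently {
        tensor.parallel_insert_slice %res into %C[...]
          : tensor<?x?xf32> into tensor<?x?xf32>
      }
    } {thread_dim_mapping = [1, 0]}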

Differential Revision: https://reviews.llvm.org/D129577

Added: 
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 01d41057e2b62..ab5ef09e18c7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -14,6 +14,7 @@
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
+class TilingInterface;
 namespace linalg {
 class GenericOp;
 class LinalgOp;

diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e1c7baf94d941..7a9c713554767 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -554,6 +554,60 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
   }];
 }
 
+def TileToForeachThreadOp :
+    Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
+      [FunctionalStyleTransformOpTrait,
+       MemoryEffectsOpInterface,
+       TransformEachOpTrait,
+       TransformOpInterface]> {
+  let description = [{
+    Tile a TilingInterface `op` to an `scf.foreach_thread`, applying
+    tiling by `num_threads`.
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    resulting `scf.foreach_thread`.
+    Zero tile sizes indicate that the dimension is not tiled and can be thought
+    of as tiling by the full size of the data.
+    It is the user's responsibility to ensure that `num_threads` is a valid
+    tiling specification (i.e. that it only tiles parallel dimensions, e.g. in
+    the Linalg case).
+    
+    #### Return modes
+    
+    This operation ignores ops that do not implement the TilingInterface and
+    drops them in the return.
+
+    If all the operations referred to by the `target` PDLOperation tile
+    successfully, the transform succeeds.
+    Otherwise the transform silently fails.
+    
+    The two returned handles point only to the subset of successfully produced
+    tiled operations, both of which may be empty.
+
+    These two handles point to:
+      - the new scf.foreach_thread op,
+      - the tiled op that implements TilingInterface.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   // TODO: dynamic number of threads.
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
+  let results = (outs PDL_Operation:$foreach_thread_op,
+                      PDL_Operation:$tiled_op);
+
+  let assemblyFormat = [{
+    $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? 
+      attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::TilingInterface target, 
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 79c40057e992e..68a41fd1b14de 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -517,6 +517,24 @@ computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
                       OpFoldResult targetSize, OpFoldResult divisor,
                       bool emitAssertions = true);
 
+/// Rewrite a TilingInterface `op` to an `scf.foreach_thread`, applying
+/// tiling by `numThreads`.
+/// If non-empty, the `threadDimMapping` is added as an attribute to the
+/// resulting `scf.foreach_thread`.
+/// Zero tile sizes indicate that the dimension is not tiled and can be thought
+/// of as tiling by the full size of the data.
+/// It is the user's responsibility to ensure that `numThreads` is a
+/// valid tiling specification (i.e. that it only tiles parallel
+/// dimensions, e.g. in the Linalg case).
+struct ForeachThreadTilingResult {
+  Operation *tileOp;
+  Operation *tiledOp;
+};
+FailureOr<ForeachThreadTilingResult>
+tileToForeachThreadOp(OpBuilder &builder, TilingInterface op,
+                      ArrayRef<OpFoldResult> numThreads,
+                      ArrayRef<int64_t> threadDimMapping = {});
+
 /// All indices returned by IndexOp should be invariant with respect to tiling.
 /// Therefore, if an operation is tiled, we have to transform the indices
 /// accordingly, i.e. offset them by the values of the corresponding induction

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 78b11829bb4fb..b0c71f514a585 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -50,6 +50,9 @@ OpFoldResult getAsOpFoldResult(Value val);
 /// value. If this fails, return the original value.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
 
+/// Convert `arrayAttr` to a vector of OpFoldResult.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index acdfe77263ac6..71a2c7d453803 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -23,16 +24,6 @@ using namespace mlir;
 using namespace mlir::linalg;
 using namespace mlir::transform;
 
-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
-  SmallVector<int64_t> result;
-  result.reserve(attr.size());
-  for (APInt value : attr.getAsValueRange<IntegerAttr>())
-    result.push_back(value.getSExtValue());
-  return result;
-}
-
 /// Extracts a vector of unsigned from an array attribute. Asserts if the
 /// attribute contains values other than integers. May truncate.
 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
@@ -160,7 +151,8 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
            << "'" << sizesAttrName << "' attribute must be an array";
   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
   size_t numExpectedLoops =
-      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+      sizesArrayAttr.size() -
+      llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
     return failure();
@@ -171,8 +163,8 @@ DiagnosedSilenceableFailure
 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                          mlir::transform::TransformState &state) {
   LinalgTilingAndFusionOptions fusionOptions;
-  fusionOptions.tileSizes = extractI64Array(getTileSizes());
-  fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
+  fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
+  fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
 
   LogicalResult result = applyTilingToAll(
       getOperation(), state.getPayloadOps(getTarget()),
@@ -209,7 +201,8 @@ void transform::FuseOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult transform::FuseOp::verify() {
-  SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
+  SmallVector<int64_t> permutation =
+      extractFromI64ArrayAttr(getTileInterchange());
   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
   if (!std::is_permutation(sequence.begin(), sequence.end(),
                            permutation.begin(), permutation.end())) {
@@ -327,7 +320,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
                              transform::TransformState &state) {
   // Convert the integer packing flags to booleans.
   SmallVector<bool> packPaddings;
-  for (int64_t packPadding : extractI64Array(getPackPaddings()))
+  for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
     packPaddings.push_back(static_cast<bool>(packPadding));
 
   // Convert the padding values to attributes.
@@ -362,13 +355,14 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
   SmallVector<SmallVector<int64_t>> transposePaddings;
   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
     transposePaddings.push_back(
-        extractI64Array(transposeVector.cast<ArrayAttr>()));
+        extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
 
   LinalgPaddingOptions paddingOptions;
   paddingOptions.setPaddingValues(paddingValues);
-  paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
+  paddingOptions.setPaddingDimensions(
+      extractFromI64ArrayAttr(getPaddingDimensions()));
   paddingOptions.setPackPaddings(packPaddings);
-  paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
+  paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
   paddingOptions.setTransposePaddings(transposePaddings);
 
   FailureOr<LinalgOp> result =
@@ -383,7 +377,8 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
 }
 
 LogicalResult transform::PadOp::verify() {
-  SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
+  SmallVector<int64_t> packPaddings =
+      extractFromI64ArrayAttr(getPackPaddings());
   if (any_of(packPaddings, [](int64_t packPadding) {
         return packPadding != 0 && packPadding != 1;
       })) {
@@ -393,7 +388,7 @@ LogicalResult transform::PadOp::verify() {
   }
 
   SmallVector<int64_t> paddingDimensions =
-      extractI64Array(getPaddingDimensions());
+      extractFromI64ArrayAttr(getPaddingDimensions());
   if (any_of(paddingDimensions,
              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
     return emitOpError()
@@ -401,7 +396,8 @@ LogicalResult transform::PadOp::verify() {
            << getPaddingDimensions();
   }
 
-  SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
+  SmallVector<int64_t> hoistPaddings =
+      extractFromI64ArrayAttr(getHoistPaddings());
   if (any_of(hoistPaddings,
              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
     return emitOpError()
@@ -657,7 +653,7 @@ DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   LinalgTilingOptions tilingOptions;
-  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
 
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -743,7 +739,7 @@ transform::TileOp::apply(TransformResults &transformResults,
 
 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
   ValueRange dynamic = getDynamicSizes();
-  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
   SmallVector<OpFoldResult> results;
   results.reserve(tileSizes.size());
   unsigned dynamicPos = 0;
@@ -773,7 +769,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 
   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   size_t numExpectedLoops =
-      staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
+      staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
   return success();
 }
@@ -794,6 +790,29 @@ void transform::TileOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
+    TilingInterface target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  IRRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  auto maybeThreadDimMappingAttr = getThreadDimMapping();
+  FailureOr<ForeachThreadTilingResult> tilingResult =
+      linalg::tileToForeachThreadOp(
+          rewriter, target, getAsOpFoldResult(getNumThreads()),
+          maybeThreadDimMappingAttr
+              ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+              : ArrayRef<int64_t>{});
+  if (failed(tilingResult))
+    return emitDefaultSilenceableFailure(target);
+  rewriter.replaceOp(target, tilingResult->tileOp->getResults());
+  results.assign({tilingResult->tileOp, tilingResult->tiledOp});
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // VectorizeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 25eab5b8ecd56..0571ff5432afb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -169,6 +169,135 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
   return spec;
 }
 
+/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
+/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location
+/// as `subsetExtractOp`.
+static void
+createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
+                                     tensor::ExtractSliceOp subsetExtractOp,
+                                     Value source, Value dest) {
+  b.create<tensor::ParallelInsertSliceOp>(
+      loc, source, dest, subsetExtractOp.getMixedOffsets(),
+      subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
+}
+
+/// Build an `affine_max` of all the `vals`.
+static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+  return b.createOrFold<AffineMaxOp>(
+      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
+      vals);
+}
+
+/// Build an `affine_min` of all the `vals`.
+static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
+  return b.createOrFold<AffineMinOp>(
+      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
+      vals);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<int64_t> threadDimMapping) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasStrideOne = [](Range r) { return isConstantIntValue(r.stride, 1); };
+  if (!llvm::all_of(loopRanges, hasStrideOne))
+    return op->emitOpError("only stride-1 supported atm");
+  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
+  auto destOperands = op.getDestinationOperands(b);
+  if (destOperands.size() != 1)
+    return op->emitOpError("only single dest operand supported atm");
+
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+        ImplicitLocOpBuilder ilocb(loc, b);
+        return materializeOpFoldResult(ilocb, ofr);
+      }));
+
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *tiledOp = nullptr;
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, materializedNonZeroNumThreads, threadDimMapping,
+      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
+        int64_t nLoops = loopRanges.size();
+        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+        tiledOffsets.reserve(nLoops);
+        tiledSizes.reserve(nLoops);
+        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
+             ++loopIdx) {
+          bool overflow = loopIdx >= numThreads.size();
+          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+          // Degenerate case: take the whole domain.
+          if (overflow || isZero) {
+            tiledOffsets.push_back(loopRanges[loopIdx].offset);
+            tiledSizes.push_back(loopRanges[loopIdx].size);
+            continue;
+          }
+
+          // Tiled case: compute the offset and size.
+          AffineExpr i, j, M, N, O;
+          bindDims(b.getContext(), i, j);
+          bindSymbols(b.getContext(), M, N, O);
+          Value size = loopRanges[loopIdx].size;
+          Value offset = loopRanges[loopIdx].offset;
+          Value threadId = threadIds[threadIdIdx];
+          // TODO: more aggressive foldings.
+          // Symbolic fixed max size per thread.
+          // TODO: floor + 0/1 depending on case for better load-balancing.
+          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
+              loc, M.ceilDiv(N),
+              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
+          // Dynamic offset shifted by threadId * maxSizePerThread.
+          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
+              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
+          // Dynamic upper-bound depending on the threadId.
+          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
+              loc, -i + M, ValueRange{offsetPerThread, size});
+          Value tileSizePerThread = buildMin(
+              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
+          tiledOffsets.push_back(offsetPerThread);
+          // TODO: if tileSizePerThread <= 0 early exit.
+          tiledSizes.push_back(
+              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
+          ++threadIdIdx;
+        }
+
+        SmallVector<Operation *> tiledOps =
+            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+                                      /*tileDestOperands=*/true);
+        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+        tiledOp = tiledOps.front();
+
+        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+        assert(tilingInterfaceOp &&
+               "Tiled op does not implement TilingInterface");
+
+        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+        // Create terminator with parallel subset insert operations.
+        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
+        OpBuilder::InsertionGuard g(b);
+        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
+        for (auto it :
+             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+                       destOperands)) {
+          createMatchingParallelSubsetInsertOp(
+              b, loc,
+              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+              std::get<1>(it), std::get<2>(it));
+        }
+      });
+  return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 419aa46329b67..c650f3f5af295 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -65,6 +65,15 @@ SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
 }
 
+/// Convert `arrayAttr` to a vector of OpFoldResult.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
+  SmallVector<OpFoldResult> res;
+  res.reserve(arrayAttr.size());
+  for (Attribute a : arrayAttr)
+    res.push_back(a);
+  return res;
+}
+
 /// If ofr is a constant integer or an IntegerAttr, return the integer.
 Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   // Case 1: Check for Constant integer.

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
new file mode 100644
index 0000000000000..89f500df24638
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+
+// Offset per thread:
+// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
+// Per thread tile size.
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
+
+module {
+// CHECK-LABEL: matmul(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+  func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //  CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+  //  CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+  //      CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor<?x?xf32>) {
+  //      CHECK:   %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+  //      CHECK:   %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+  //      CHECK:   %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+  //      CHECK:   %[[RES:.*]] = linalg.matmul
+  // CHECK-SAME:      ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+  // CHECK-SAME:     outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
+  // CHECK-SAME:       tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK-NEXT:   }
+  // CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
+    %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                      outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+    return %0 : tensor<?x?xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    pdl.pattern @match_linalg_matmul : benefit(1) {
+      %0 = operands
+      %1 = types
+      %2 = operation "linalg.matmul"(%0 : !pdl.range<value>)  -> (%1 : !pdl.range<type>)
+      rewrite %2 with "transform.dialect"
+    }
+    transform.sequence %arg0 {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = pdl_match @match_linalg_matmul in %arg1
+      %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])
+    }
+  }
+}


        

