[Mlir-commits] [mlir] 18b92c6 - [mlir][Linalg] Add a TileToForeachThread transform.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jul 19 04:59:19 PDT 2022
Author: Nicolas Vasilache
Date: 2022-07-19T04:56:11-07:00
New Revision: 18b92c66fe59a44f50bc211a418eaf48fe1cf7c1
URL: https://github.com/llvm/llvm-project/commit/18b92c66fe59a44f50bc211a418eaf48fe1cf7c1
DIFF: https://github.com/llvm/llvm-project/commit/18b92c66fe59a44f50bc211a418eaf48fe1cf7c1.diff
LOG: [mlir][Linalg] Add a TileToForeachThread transform.
This revision adds a new transformation to tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
tiling by `num_threads`.
If non-empty, the `threadDimMapping` is added as an attribute to the resulting `scf.foreach_thread`.
A 0 entry in `num_threads` (i.e. tile by the full size of the data) is used to
encode that the corresponding dimension is not tiled.
Differential Revision: https://reviews.llvm.org/D129577
Added:
mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 01d41057e2b62..ab5ef09e18c7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -14,6 +14,7 @@
#include "mlir/IR/OpImplementation.h"
namespace mlir {
+class TilingInterface;
namespace linalg {
class GenericOp;
class LinalgOp;
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e1c7baf94d941..7a9c713554767 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -554,6 +554,60 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
}];
}
+// Transform op that tiles each TilingInterface payload op into an
+// `scf.foreach_thread` according to a static `num_threads` specification.
+def TileToForeachThreadOp :
+    Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
+      [FunctionalStyleTransformOpTrait,
+       MemoryEffectsOpInterface,
+       TransformEachOpTrait,
+       TransformOpInterface]> {
+  let description = [{
+    Tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
+    tiling by `num_threads`.
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    resulting `scf.foreach_thread`.
+    A zero entry in `num_threads` indicates that the corresponding dimension
+    is not tiled, and can be thought of as tiling by the full size of data.
+    It is the user's responsibility to ensure that `num_threads` is a valid
+    tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+    Linalg case).
+
+    #### Return modes
+
+    This operation ignores ops that do not implement the TilingInterface and
+    drops them in the return.
+
+    If all the operations referred to by the `target` PDLOperation tile
+    successfully, the transform succeeds.
+    Otherwise the transform silently fails.
+
+    The 2 returned handles point to only the subset of successfully produced
+    tiled operations, which can all be empty.
+
+    These 2 returned handles point to:
+      - the new scf.foreach_thread op,
+      - the tiled op that implements TilingInterface.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   // TODO: dynamic number of threads.
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
+  let results = (outs PDL_Operation:$foreach_thread_op,
+                      PDL_Operation:$tiled_op);
+
+  let assemblyFormat = [{
+    $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)?
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::TilingInterface target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 79c40057e992e..68a41fd1b14de 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -517,6 +517,24 @@ computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
OpFoldResult targetSize, OpFoldResult divisor,
bool emitAssertions = true);
+/// Result of tiling a TilingInterface op to an `scf.foreach_thread`.
+struct ForeachThreadTilingResult {
+  /// The newly created `scf.foreach_thread` op.
+  Operation *tileOp;
+  /// The tiled op created inside the body of `tileOp`.
+  Operation *tiledOp;
+};
+
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
+/// tiling by `numThreads`.
+/// If non-empty, the `threadDimMapping` is added as an attribute to the
+/// resulting `scf.foreach_thread`.
+/// A zero entry in `numThreads` indicates that the corresponding dimension is
+/// not tiled, and can be thought of as tiling by the full size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a
+/// valid tiling specification (i.e. that only tiles parallel
+/// dimensions, e.g. in the Linalg case).
+/// Returns the created loop and tiled op on success, failure otherwise.
+FailureOr<ForeachThreadTilingResult>
+tileToForeachThreadOp(OpBuilder &builder, TilingInterface op,
+                      ArrayRef<OpFoldResult> numThreads,
+                      ArrayRef<int64_t> threadDimMapping = {});
+
+
/// All indices returned by IndexOp should be invariant with respect to tiling.
/// Therefore, if an operation is tiled, we have to transform the indices
/// accordingly, i.e. offset them by the values of the corresponding induction
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 78b11829bb4fb..b0c71f514a585 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -50,6 +50,9 @@ OpFoldResult getAsOpFoldResult(Value val);
/// value. If this fails, return the original value.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
+/// Convert `arrayAttr` to a vector of OpFoldResult.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index acdfe77263ac6..71a2c7d453803 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -23,16 +24,6 @@ using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;
-/// Extracts a vector of int64_t from an array attribute. Asserts if the
-/// attribute contains values other than integers.
-static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
- SmallVector<int64_t> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getSExtValue());
- return result;
-}
-
/// Extracts a vector of unsigned from an array attribute. Asserts if the
/// attribute contains values other than intergers. May truncate.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
@@ -160,7 +151,8 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
<< "'" << sizesAttrName << "' attribute must be an array";
Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
size_t numExpectedLoops =
- sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
+ sizesArrayAttr.size() -
+ llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
return failure();
@@ -171,8 +163,8 @@ DiagnosedSilenceableFailure
transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
LinalgTilingAndFusionOptions fusionOptions;
- fusionOptions.tileSizes = extractI64Array(getTileSizes());
- fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
+ fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
+ fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
LogicalResult result = applyTilingToAll(
getOperation(), state.getPayloadOps(getTarget()),
@@ -209,7 +201,8 @@ void transform::FuseOp::print(OpAsmPrinter &p) {
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
+ SmallVector<int64_t> permutation =
+ extractFromI64ArrayAttr(getTileInterchange());
auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
if (!std::is_permutation(sequence.begin(), sequence.end(),
permutation.begin(), permutation.end())) {
@@ -327,7 +320,7 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
transform::TransformState &state) {
// Convert the integer packing flags to booleans.
SmallVector<bool> packPaddings;
- for (int64_t packPadding : extractI64Array(getPackPaddings()))
+ for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
packPaddings.push_back(static_cast<bool>(packPadding));
// Convert the padding values to attributes.
@@ -362,13 +355,14 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
SmallVector<SmallVector<int64_t>> transposePaddings;
for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
transposePaddings.push_back(
- extractI64Array(transposeVector.cast<ArrayAttr>()));
+ extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
LinalgPaddingOptions paddingOptions;
paddingOptions.setPaddingValues(paddingValues);
- paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
+ paddingOptions.setPaddingDimensions(
+ extractFromI64ArrayAttr(getPaddingDimensions()));
paddingOptions.setPackPaddings(packPaddings);
- paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
+ paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
paddingOptions.setTransposePaddings(transposePaddings);
FailureOr<LinalgOp> result =
@@ -383,7 +377,8 @@ transform::PadOp::applyToOne(linalg::LinalgOp target,
}
LogicalResult transform::PadOp::verify() {
- SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
+ SmallVector<int64_t> packPaddings =
+ extractFromI64ArrayAttr(getPackPaddings());
if (any_of(packPaddings, [](int64_t packPadding) {
return packPadding != 0 && packPadding != 1;
})) {
@@ -393,7 +388,7 @@ LogicalResult transform::PadOp::verify() {
}
SmallVector<int64_t> paddingDimensions =
- extractI64Array(getPaddingDimensions());
+ extractFromI64ArrayAttr(getPaddingDimensions());
if (any_of(paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
return emitOpError()
@@ -401,7 +396,8 @@ LogicalResult transform::PadOp::verify() {
<< getPaddingDimensions();
}
- SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
+ SmallVector<int64_t> hoistPaddings =
+ extractFromI64ArrayAttr(getHoistPaddings());
if (any_of(hoistPaddings,
[](int64_t hoistPadding) { return hoistPadding < 0; })) {
return emitOpError()
@@ -657,7 +653,7 @@ DiagnosedSilenceableFailure
transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
- SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+ SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
@@ -743,7 +739,7 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
ValueRange dynamic = getDynamicSizes();
- SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+ SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
SmallVector<OpFoldResult> results;
results.reserve(tileSizes.size());
unsigned dynamicPos = 0;
@@ -773,7 +769,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
size_t numExpectedLoops =
- staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
+ staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
return success();
}
@@ -794,6 +790,29 @@ void transform::TileOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+/// Tile a single `target` payload op to an `scf.foreach_thread` and, on
+/// success, report the new loop and the tiled op (in that order) in `results`.
+/// Emits the default silenceable failure when tiling fails.
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
+    TilingInterface target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  // The thread-dimension mapping is optional; an absent attribute is
+  // forwarded as an empty mapping.
+  SmallVector<int64_t> threadDimMapping;
+  if (auto mappingAttr = getThreadDimMapping())
+    threadDimMapping = extractFromI64ArrayAttr(*mappingAttr);
+
+  // New IR is created right before the op being tiled.
+  IRRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<ForeachThreadTilingResult> tilingResult =
+      linalg::tileToForeachThreadOp(rewriter, target,
+                                    getAsOpFoldResult(getNumThreads()),
+                                    threadDimMapping);
+  if (failed(tilingResult))
+    return emitDefaultSilenceableFailure(target);
+
+  // Replace the payload op by the results of the new loop, then hand back
+  // both produced ops.
+  Operation *foreachThreadOp = tilingResult->tileOp;
+  Operation *tiledOp = tilingResult->tiledOp;
+  rewriter.replaceOp(target, foreachThreadOp->getResults());
+  results.assign({foreachThreadOp, tiledOp});
+  return DiagnosedSilenceableFailure(success());
+}
+
//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 25eab5b8ecd56..0571ff5432afb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -169,6 +169,135 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
return spec;
}
+/// Create a `tensor::ParallelInsertSliceOp` inserting `source` into `dest`,
+/// reusing the mixed offsets, sizes and strides of `subsetExtractOp` so that
+/// the insertion targets the same subset location as the extraction.
+static void
+createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
+                                     tensor::ExtractSliceOp subsetExtractOp,
+                                     Value source, Value dest) {
+  auto offsets = subsetExtractOp.getMixedOffsets();
+  auto sizes = subsetExtractOp.getMixedSizes();
+  auto strides = subsetExtractOp.getMixedStrides();
+  b.create<tensor::ParallelInsertSliceOp>(loc, source, dest, offsets, sizes,
+                                          strides);
+}
+
+/// Build (or fold) an `AffineMaxOp` over all the `vals`.
+static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+  // The multi-dim identity map has one result per operand, so the max over
+  // the map results is the max over `vals`.
+  AffineMap identityMap =
+      AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext());
+  return b.createOrFold<AffineMaxOp>(loc, identityMap, vals);
+}
+
+/// Build (or fold) an `AffineMinOp` over all the `vals`.
+static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
+  // The multi-dim identity map has one result per operand, so the min over
+  // the map results is the min over `vals`.
+  AffineMap identityMap =
+      AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext());
+  return b.createOrFold<AffineMinOp>(loc, identityMap, vals);
+}
+
+/// Tile `op` into a new `scf.foreach_thread` with one thread dimension per
+/// non-zero entry of `numThreads`; `threadDimMapping` is forwarded to the
+/// loop builder. Loops whose `numThreads` entry is zero, or that have no
+/// entry at all, are not tiled and keep their full iteration range.
+/// Emits an op error and returns failure when the iteration domain is empty,
+/// when any loop has a stride other than 1, or when `op` does not have exactly
+/// one destination operand.
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<int64_t> threadDimMapping) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  // True for any range whose stride is not the constant 1.
+  auto hasNonUnitStride = [](Range r) {
+    return !isConstantIntValue(r.stride, 1);
+  };
+  if (llvm::any_of(loopRanges, hasNonUnitStride))
+    return op->emitOpError("only stride-1 supported atm");
+  // TODO: support `getTiledImplementation` with >1 produced tiled ops.
+  auto destOperands = op.getDestinationOperands(b);
+  if (destOperands.size() != 1)
+    return op->emitOpError("only single dest operand supported atm");
+
+  // Only non-zero entries of `numThreads` become loop dimensions.
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+        ImplicitLocOpBuilder ilocb(loc, b);
+        return materializeOpFoldResult(ilocb, ofr);
+      }));
+
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Operation *tiledOp = nullptr;
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, materializedNonZeroNumThreads, threadDimMapping,
+      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
+        // Same unsigned type as the loop indices to avoid a mixed-sign
+        // comparison in the loop condition.
+        unsigned nLoops = loopRanges.size();
+        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+        tiledOffsets.reserve(nLoops);
+        tiledSizes.reserve(nLoops);
+        // `threadIdIdx` only advances on tiled loops: it indexes both
+        // `threadIds` and `materializedNonZeroNumThreads`.
+        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
+             ++loopIdx) {
+          bool overflow = loopIdx >= numThreads.size();
+          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+          // Degenerate case: take the whole domain.
+          if (overflow || isZero) {
+            tiledOffsets.push_back(loopRanges[loopIdx].offset);
+            tiledSizes.push_back(loopRanges[loopIdx].size);
+            continue;
+          }
+
+          // Tiled case: compute the offset and size.
+          AffineExpr i, j, M, N, O;
+          bindDims(b.getContext(), i, j);
+          bindSymbols(b.getContext(), M, N, O);
+          Value size = loopRanges[loopIdx].size;
+          Value offset = loopRanges[loopIdx].offset;
+          Value threadId = threadIds[threadIdIdx];
+          // TODO: more aggressive foldings.
+          // Symbolic fixed max size per thread.
+          // TODO: floor + 0/1 depending on case for better load-balancing.
+          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
+              loc, M.ceilDiv(N),
+              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
+          // Dynamic offset shifted by threadId * maxSizePerThread.
+          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
+              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
+          // Dynamic upper-bound depending on the threadId.
+          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
+              loc, -i + M, ValueRange{offsetPerThread, size});
+          Value tileSizePerThread = buildMin(
+              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
+          tiledOffsets.push_back(offsetPerThread);
+          // TODO: if tileSizePerThread <= 0 early exit.
+          // Clamp at zero so the tile size is never negative.
+          tiledSizes.push_back(
+              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
+          ++threadIdIdx;
+        }
+
+        SmallVector<Operation *> tiledOps =
+            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+                                      /*tileDestOperands=*/true);
+        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+        tiledOp = tiledOps.front();
+
+        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+        assert(tilingInterfaceOp &&
+               "Tiled op does not implement TilingInterface");
+
+        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+        // Create terminator with parallel subset insert operations.
+        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
+        OpBuilder::InsertionGuard g(b);
+        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
+        // Each tiled destination operand must be defined by a
+        // tensor.extract_slice (enforced by the `cast` below); its result is
+        // inserted back into the original destination at that same slice.
+        for (auto it :
+             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+                       destOperands)) {
+          createMatchingParallelSubsetInsertOp(
+              b, loc,
+              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+              std::get<1>(it), std::get<2>(it));
+        }
+      });
+  return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
+}
+
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 419aa46329b67..c650f3f5af295 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -65,6 +65,15 @@ SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
}
+/// Convert `arrayAttr` to a vector of OpFoldResult, one entry per element
+/// attribute and in the same order.
+SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
+  // Each element is wrapped as-is; the element kind is not inspected.
+  return llvm::to_vector(llvm::map_range(
+      arrayAttr, [](Attribute attr) -> OpFoldResult { return attr; }));
+}
+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
new file mode 100644
index 0000000000000..89f500df24638
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+
+// Offset per thread:
+// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
+// Per thread tile size.
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
+
+module {
+// CHECK-LABEL: matmul(
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+  func.func @matmul(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+    // CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+    // CHECK: scf.foreach_thread ({{.*}}) in (%[[C10]], %[[C20]]) -> (tensor<?x?xf32>) {
+    // CHECK: %[[tA:.*]] = tensor.extract_slice %[[A]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+    // CHECK: %[[tB:.*]] = tensor.extract_slice %[[B]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+    // CHECK: %[[tC:.*]] = tensor.extract_slice %[[C]]{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+    // CHECK: %[[RES:.*]] = linalg.matmul
+    // CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+    // CHECK-SAME: outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+    // CHECK-NEXT: scf.foreach_thread.perform_concurrently {
+    // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
+    // CHECK-SAME: tensor<?x?xf32> into tensor<?x?xf32>
+    // CHECK-NEXT: }
+    // CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
+    %result = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+                            outs(%init : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+    return %result : tensor<?x?xf32>
+  }
+
+  // Transform script: match the matmul above and tile it with
+  // num_threads = [10, 20], mapped to thread dims [1, 0].
+  transform.with_pdl_patterns {
+  ^bb0(%root: !pdl.operation):
+    pdl.pattern @match_linalg_matmul : benefit(1) {
+      %args = operands
+      %result_types = types
+      %matmul = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%result_types : !pdl.range<type>)
+      rewrite %matmul with "transform.dialect"
+    }
+    transform.sequence %root {
+    ^bb1(%payload: !pdl.operation):
+      %matched = pdl_match @match_linalg_matmul in %payload
+      %tiled:2 = transform.structured.tile_to_foreach_thread_op %matched [10, 20] (mapped to dims [1, 0])
+    }
+  }
+}
More information about the Mlir-commits
mailing list