[Mlir-commits] [mlir] c9fb3c6 - [mlir][Tensor] Update ParallelInsertSliceOp semantics to match that of InsertSliceOp
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 4 02:37:51 PDT 2022
Author: Nicolas Vasilache
Date: 2022-07-04T02:37:46-07:00
New Revision: c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b
URL: https://github.com/llvm/llvm-project/commit/c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b
DIFF: https://github.com/llvm/llvm-project/commit/c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b.diff
LOG: [mlir][Tensor] Update ParallelInsertSliceOp semantics to match that of InsertSliceOp
This revision updates the op semantics to also allow rank-reducing behavior, and
updates the implementation to reuse code between the sequential and the parallel
versions of the op.
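Concretely, the relaxed semantics admit IR along these lines (a minimal sketch, not
taken from this patch; `%slice`, `%dest` and `%num_threads` are hypothetical values):

```mlir
%0 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
  scf.foreach_thread.perform_concurrently {
    // Rank-reducing insertion: the source omits the leading unit dimension
    // of the inferred tensor<1x5xf32> slice type.
    tensor.parallel_insert_slice %slice into %dest[%tidx, 0] [1, 5] [1, 1]
        : tensor<5xf32> into tensor<?x?xf32>
  }
}
```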
Depends on D128920
Differential Revision: https://reviews.llvm.org/D128985
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index d6001dc49ff9a..7813f887a6ae6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -608,6 +608,11 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
RankedTensorType getType() {
return getResult().getType().cast<RankedTensorType>();
}
+
+ /// The `dest` type is the same as the result type.
+ RankedTensorType getDestType() {
+ return getType();
+ }
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
@@ -1090,6 +1095,41 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
Note that we cannot mark this operation as pure (NoSideEffects), even
though it has no side effects, because it will get DCEd during
canonicalization.
+
+ The parallel_insert_slice operation supports the following arguments:
+
+ * source: the tensor that is inserted.
+ * dest: the tensor into which the source tensor is inserted.
+ * offsets: tensor-rank number of offsets into the `dest` tensor into which
+ the slice is inserted.
+ * sizes: tensor-rank number of sizes which specify the sizes of the source
+ tensor type.
+ * strides: tensor-rank number of strides that specify subsampling in each
+ dimension.
+
+    The representation based on offsets, sizes and strides supports a
+    partially-static specification via attributes specified through the
+    `static_offsets`, `static_sizes` and `static_strides` arguments. The
+    special sentinel values ShapedType::kDynamicSize and
+    ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+    a dynamic value.
+
+ After buffer allocation, the "parallel_insert_slice" op is expected to lower
+ into a memref.subview op.
+
+ A parallel_insert_slice operation may additionally specify insertion into a
+ tensor of higher rank than the source tensor, along dimensions that are
+ statically known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.
+ The rank-altering behavior of tensor.parallel_insert_slice matches the
+ rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
+
+ Verification in the rank-reduced case:
+ ======================================
+ The same verification discussion and mechanisms apply as for ExtractSliceOp.
+ Unlike ExtractSliceOp however, there is no need for a specific inference.
}];
let arguments = (ins
@@ -1117,6 +1157,10 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
return getSource().getType().cast<RankedTensorType>();
}
+ RankedTensorType getDestType() {
+ return getDest().getType().cast<RankedTensorType>();
+ }
+
ParallelCombiningOpInterface getParallelCombiningParent() {
return dyn_cast<ParallelCombiningOpInterface>(
getOperation()->getParentOp());
@@ -1125,7 +1169,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
/// Return the expected rank of each of the `static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned rank = getSourceType().getRank();
+ unsigned rank = getDestType().getRank();
return {rank, rank, rank};
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3fa3c70e7f9ff..a9437634b285d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1123,9 +1123,9 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
// Verify result type against inferred type.
- auto expectedType = ExtractSliceOp::inferResultType(
+ RankedTensorType expectedType = ExtractSliceOp::inferResultType(
getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
- auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+ SliceVerificationResult result = isRankReducedType(expectedType, getType());
return produceSliceErrorMsg(result, *this, expectedType);
}
@@ -1487,17 +1487,18 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+/// Rank-reducing type verification for both InsertSliceOp and
+/// ParallelInsertSliceOp.
static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
ArrayAttr staticOffsets, ArrayAttr staticSizes,
ArrayAttr staticStrides,
ShapedType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type inference.
- auto expected = ExtractSliceOp::inferResultType(
- dstType, extractFromI64ArrayAttr(staticOffsets),
- extractFromI64ArrayAttr(staticSizes),
- extractFromI64ArrayAttr(staticStrides))
- .cast<ShapedType>();
+ RankedTensorType expected = ExtractSliceOp::inferResultType(
+ dstType, extractFromI64ArrayAttr(staticOffsets),
+ extractFromI64ArrayAttr(staticSizes),
+ extractFromI64ArrayAttr(staticStrides));
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
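The helper infers the full slice type from the destination type and the static
offsets/sizes/strides, then only requires the source to be a rank-reduced version of
that inferred type. A sketch of IR this accepts (hypothetical shapes and values):

```mlir
// The slice inferred from the destination and sizes [1, 1, 4] is
// tensor<1x1x4xf32>; tensor<4xf32> only drops unit dimensions from it,
// so verification succeeds.
%r = tensor.insert_slice %src into %dst[0, 0, 0] [1, 1, 4] [1, 1, 1]
    : tensor<4xf32> into tensor<8x16x4xf32>
```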
@@ -1506,7 +1507,7 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
ShapedType expectedType;
- auto result =
+ SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
return produceSliceErrorMsg(result, *this, expectedType);
@@ -1514,6 +1515,7 @@ LogicalResult InsertSliceOp::verify() {
/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
+/// This works similarly when the second op is a ParallelInsertSliceOp.
///
/// Example:
///
@@ -1527,8 +1529,11 @@ LogicalResult InsertSliceOp::verify() {
/// ```mlir
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
/// ```
-static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
- auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
+ auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!prevInsertOp ||
@@ -1540,14 +1545,32 @@ static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
return success();
}
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
- if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
- getSourceType() == getType() &&
- succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
- return this->getSource();
- if (succeeded(foldInsertAfterInsertSlice(*this)))
- return getResult();
- return OpFoldResult();
+/// Shared folding logic for InsertSliceOp and ParallelInsertSliceOp; the
+/// return type differs between the two, so the result is wrapped in a
+/// FailureOr.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
+ if (insertOp.getSourceType().hasStaticShape() &&
+ insertOp.getDestType().hasStaticShape() &&
+ insertOp.getSourceType() == insertOp.getDestType() &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
+ insertOp, insertOp.getDestType())))
+ return static_cast<OpFoldResult>(insertOp.getSource());
+ if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
+ // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
+ // return OpFoldResult().
+ if (std::is_same<InsertOpTy, InsertSliceOp>::value)
+ return static_cast<OpFoldResult>(insertOp->getResult(0));
+ else
+ return OpFoldResult();
+ }
+ return failure();
+}
+
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
+ auto maybeOpFoldResult = foldInsertOp(*this, operands);
+ return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
}
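The first case of foldInsertOp is the identity fold: when the source has the same
static type as the destination and the offsets, sizes and strides describe the whole
tensor, the op is replaced by its source. A sketch (hypothetical values):

```mlir
// Offsets are all 0, strides all 1, and the sizes match the static
// destination shape, so all uses of %r fold to %t.
%r = tensor.insert_slice %t into %d[0, 0] [4, 8] [1, 1]
    : tensor<4x8xf32> into tensor<4x8xf32>
```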
LogicalResult InsertSliceOp::reifyResultShapes(
@@ -1562,12 +1585,15 @@ LogicalResult InsertSliceOp::reifyResultShapes(
namespace {
/// Pattern to rewrite a insert_slice op with constant arguments.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
- : public OpRewritePattern<InsertSliceOp> {
+ : public OpRewritePattern<InsertOpTy> {
public:
- using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+ using OpRewritePattern<InsertOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
// No constant operand, just return.
if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
@@ -1587,13 +1613,20 @@ class InsertSliceOpConstantArgumentFolder final
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
- insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
+ insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.getSource();
- if (sourceType != insertSliceOp.getSourceType())
+ if (sourceType != insertSliceOp.getSourceType()) {
+ OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+      // that the insertion point is just before the ParallelCombiningOp in the
+      // parallel case.
+ if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
sourceType, toInsert);
- rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ }
+ rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
mixedSizes, mixedStrides);
return success();
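On the IR level, the pattern turns offsets, sizes and strides supplied as constant
SSA values into their static attribute form, e.g. (a sketch with hypothetical values):

```mlir
// Before canonicalization.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%r = tensor.insert_slice %s into %d[%c0, %c0] [4, 4] [%c1, %c1]
    : tensor<4x4xf32> into tensor<8x8xf32>
// After canonicalization the constants become static attributes:
//   %r = tensor.insert_slice %s into %d[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
```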
@@ -1618,10 +1651,13 @@ class InsertSliceOpConstantArgumentFolder final
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
-struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
- using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
+ using OpRewritePattern<InsertOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
@@ -1643,24 +1679,27 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
auto src =
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-
- auto srcType = src.getType().cast<ShapedType>();
- auto dstType = dst.getType().cast<ShapedType>();
+ auto srcType = src.getType().template cast<ShapedType>();
+ auto dstType = dst.getType().template cast<ShapedType>();
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
insertSliceOp.getStaticSizes(),
insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
- Value replacement = rewriter.create<InsertSliceOp>(
+ Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
- if (replacement.getType() != insertSliceOp.getType()) {
- replacement = rewriter.create<tensor::CastOp>(
- insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
+ // In the parallel case there is no result and so nothing to cast.
+ bool isParallelInsert =
+ std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
+ if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
+ replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+ insertSliceOp.getDestType(),
+ replacement->getResult(0));
}
- rewriter.replaceOp(insertSliceOp, replacement);
+ rewriter.replaceOp(insertSliceOp, replacement->getResults());
return success();
}
};
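For the destination-cast case mentioned above, the rewrite inserts into the cast's
source and casts the result back to the original type, e.g. (a sketch with
hypothetical values):

```mlir
// Before canonicalization.
%0 = tensor.cast %d : tensor<8x16xf32> to tensor<?x?xf32>
%r = tensor.insert_slice %s into %0[0, 0] [4, 4] [1, 1]
    : tensor<4x4xf32> into tensor<?x?xf32>
// After: the insertion targets %d directly and the result is cast back so
// the type of %r does not change.
//   %1 = tensor.insert_slice %s into %d[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x16xf32>
//   %r = tensor.cast %1 : tensor<8x16xf32> to tensor<?x?xf32>
```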
@@ -1684,14 +1723,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
/// : tensor<64x64xf32> into ...
/// ```
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
- : public OpRewritePattern<InsertSliceOp> {
- using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+ : public OpRewritePattern<InsertOpTy> {
+ using OpRewritePattern<InsertOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType srcType = insertSliceOp.getSourceType();
- if (srcType.getRank() != insertSliceOp.getType().getRank())
+ if (srcType.getRank() != insertSliceOp.getDestType().getRank())
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
srcType.getShape().end());
@@ -1713,12 +1755,19 @@ struct InsertSliceOpSourceCastInserter final
// 2) "More static" than srcType.
// 3) Cast-compatible with srcType.
// Insert the cast.
+ OpBuilder::InsertionGuard g(rewriter);
+    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+    // that the insertion point is just before the ParallelCombiningOp in the
+    // parallel case.
+ if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
- rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
return success();
}
};
@@ -1726,8 +1775,9 @@ struct InsertSliceOpSourceCastInserter final
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
- InsertSliceOpSourceCastInserter>(context);
+ results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
+ InsertSliceOpCastFolder<InsertSliceOp>,
+ InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
@@ -2234,7 +2284,12 @@ LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
- return success();
+
+ ShapedType expectedType;
+ SliceVerificationResult result =
+ verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
+ getStaticSizes(), getStaticStrides(), &expectedType);
+ return produceSliceErrorMsg(result, *this, expectedType);
}
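With the shared check in place, a parallel_insert_slice whose source type cannot be
derived from the destination and the static sizes, even after dropping unit
dimensions, is now rejected. A sketch of IR that no longer verifies (hypothetical
values, assumed to sit inside an scf.foreach_thread.perform_concurrently region):

```mlir
// The slice inferred from the destination and sizes [1, 5] is
// tensor<1x5xf32>; tensor<4xf32> is not a rank-reduced version of it,
// so the verifier reports a size mismatch.
tensor.parallel_insert_slice %src into %dest[%tidx, 0] [1, 5] [1, 1]
    : tensor<4xf32> into tensor<?x?xf32>
```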
namespace {
@@ -2263,51 +2318,37 @@ class ParallelInsertSliceOpConstantArgumentFolder final
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
+ auto sourceType =
+ tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+ insertSliceOp.getSourceType().getRank(),
+ insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
+ mixedStrides);
+ Value toInsert = insertSliceOp.getSource();
+ if (sourceType != insertSliceOp.getSourceType()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+ sourceType, toInsert);
+ }
rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
- insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
- mixedOffsets, mixedSizes, mixedStrides);
+ insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
+ mixedSizes, mixedStrides);
return success();
}
};
} // namespace
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-/// %1 = compute_some_tensor() : tensor<64xf32>
-/// %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-/// scf.foreach_thread.perform_concurrently {
-/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-/// tensor<?xf32> into tensor<128xf32>
-/// }
-/// }
-/// ```
-///
-/// is folded into:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-/// %1 = compute_some_tensor() : tensor<64xf32>
-/// scf.foreach_thread.perform_concurrently {
-/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-/// tensor<64xf32> into tensor<128xf32>
-/// }
-/// }
-/// ```
LogicalResult
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
- if (!sourceCast)
- return failure();
- getSourceMutable().assign(sourceCast.getSource());
- return success();
+ return foldInsertOp(*this, operands);
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+ results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
+ InsertSliceOpCastFolder<ParallelInsertSliceOp>,
+ InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6196f4d8205aa..d07f3e894e242 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1429,23 +1429,25 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
// -----
// CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
// CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
func.func @canonicalize_parallel_insert_slice_indices(
- %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+ %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
%num_threads : index) -> tensor<?x?xf32>
{
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ // CHECK-NOT: tensor.cast
// CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
// CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
%2 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
+ %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
scf.foreach_thread.perform_concurrently {
- tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+ tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
}
}
return %2 : tensor<?x?xf32>