[Mlir-commits] [mlir] c9fb3c6 - [mlir][Tensor] Update ParallelInsertSliceOp semantics to match that of InsertSliceOp

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 4 02:37:51 PDT 2022


Author: Nicolas Vasilache
Date: 2022-07-04T02:37:46-07:00
New Revision: c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b

URL: https://github.com/llvm/llvm-project/commit/c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b
DIFF: https://github.com/llvm/llvm-project/commit/c9fb3c6ea6ccffee74f25d14ad5f9c74f45a715b.diff

LOG: [mlir][Tensor] Update ParallelInsertSliceOp semantics to match that of InsertSliceOp

This revision updates the op semantics to also allow rank-reducing behavior, and
updates the implementation to reuse code between the sequential and parallel
versions of the op.
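
For illustration (not part of the original log message), here is a minimal sketch
of the rank-reducing form this revision now accepts, modeled on the updated
canonicalize.mlir test; the function and value names are hypothetical:

```mlir
// The tensor<5xf32> source is inserted into a [1, 5] slice of the rank-2
// destination; the leading unit dimension is dropped on the source side.
func.func @rank_reducing_parallel_insert(
    %src : tensor<5xf32>, %dst : tensor<?x?xf32>,
    %num_threads : index) -> tensor<?x?xf32> {
  %0 = scf.foreach_thread (%tid) in (%num_threads) -> (tensor<?x?xf32>) {
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %src into %dst[%tid, 0] [1, 5] [1, 1]
        : tensor<5xf32> into tensor<?x?xf32>
    }
  }
  return %0 : tensor<?x?xf32>
}
```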

Depends on D128920

Differential Revision: https://reviews.llvm.org/D128985

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index d6001dc49ff9a..7813f887a6ae6 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -608,6 +608,11 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
+    
+    /// The `dest` type is the same as the result type.
+    RankedTensorType getDestType() {
+      return getType();
+    }
 
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
@@ -1090,6 +1095,41 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     Note that we cannot mark this operation as pure (NoSideEffects), even
     though it has no side effects, because it will get DCEd during
     canonicalization.
+
+    The parallel_insert_slice operation supports the following arguments:
+
+    * source: the tensor that is inserted.
+    * dest: the tensor into which the source tensor is inserted.
+    * offsets: tensor-rank number of offsets into the `dest` tensor into which
+               the slice is inserted.
+    * sizes: tensor-rank number of sizes which specify the sizes of the source
+             tensor type.
+    * strides: tensor-rank number of strides that specify subsampling in each
+               dimension.
+
+    The representation based on offsets, sizes and strides supports a
+    partially-static specification via attributes specified through the
+    `static_offsets`, `static_sizes` and `static_strides` arguments. The
+    special sentinel values ShapedType::kDynamicSize and
+    ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+    a dynamic value.
+
+    After buffer allocation, the "parallel_insert_slice" op is expected to lower
+    into a memref.subview op.
+
+    A parallel_insert_slice operation may additionally specify insertion into a
+    tensor of higher rank than the source tensor, along dimensions that are 
+    statically known to be of size 1.
+    This rank-altering behavior is not required by the op semantics: this
+    flexibility allows progressively dropping unit dimensions while lowering
+    between different flavors of ops that operate on tensors.
+    The rank-altering behavior of tensor.parallel_insert_slice matches the 
+    rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
+
+    Verification in the rank-reduced case:
+    ======================================
+    The same verification discussion and mechanisms apply as for ExtractSliceOp.
+    Unlike ExtractSliceOp however, there is no need for a specific inference.
   }];
 
   let arguments = (ins
@@ -1117,6 +1157,10 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
       return getSource().getType().cast<RankedTensorType>();
     }
 
+    RankedTensorType getDestType() {
+      return getDest().getType().cast<RankedTensorType>();
+    }
+
     ParallelCombiningOpInterface getParallelCombiningParent() {
       return dyn_cast<ParallelCombiningOpInterface>(
         getOperation()->getParentOp());
@@ -1125,7 +1169,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     /// Return the expected rank of each of the `static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank = getSourceType().getRank();
+      unsigned rank = getDestType().getRank();
       return {rank, rank, rank};
     }
 

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3fa3c70e7f9ff..a9437634b285d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1123,9 +1123,9 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
 /// Verifier for ExtractSliceOp.
 LogicalResult ExtractSliceOp::verify() {
   // Verify result type against inferred type.
-  auto expectedType = ExtractSliceOp::inferResultType(
+  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
       getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
-  auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+  SliceVerificationResult result = isRankReducedType(expectedType, getType());
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
@@ -1487,17 +1487,18 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// Rank-reducing type verification for both InsertSliceOp and
+/// ParallelInsertSliceOp.
 static SliceVerificationResult
 verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
                     ArrayAttr staticOffsets, ArrayAttr staticSizes,
                     ArrayAttr staticStrides,
                     ShapedType *expectedType = nullptr) {
   // insert_slice is the inverse of extract_slice, use the same type inference.
-  auto expected = ExtractSliceOp::inferResultType(
-                      dstType, extractFromI64ArrayAttr(staticOffsets),
-                      extractFromI64ArrayAttr(staticSizes),
-                      extractFromI64ArrayAttr(staticStrides))
-                      .cast<ShapedType>();
+  RankedTensorType expected = ExtractSliceOp::inferResultType(
+      dstType, extractFromI64ArrayAttr(staticOffsets),
+      extractFromI64ArrayAttr(staticSizes),
+      extractFromI64ArrayAttr(staticStrides));
   if (expectedType)
     *expectedType = expected;
   return isRankReducedType(expected, srcType);
@@ -1506,7 +1507,7 @@ verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
   ShapedType expectedType;
-  auto result =
+  SliceVerificationResult result =
       verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides(), &expectedType);
   return produceSliceErrorMsg(result, *this, expectedType);
@@ -1514,6 +1515,7 @@ LogicalResult InsertSliceOp::verify() {
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
+/// This works similarly when the second op is a ParallelInsertSliceOp.
 ///
 /// Example:
 ///
@@ -1527,8 +1529,11 @@ LogicalResult InsertSliceOp::verify() {
 /// ```mlir
 ///   %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
 /// ```
-static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
-  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
+  auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
 
   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
   if (!prevInsertOp ||
@@ -1540,14 +1545,32 @@ static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
   return success();
 }
 
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
-  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
-      getSourceType() == getType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
-    return this->getSource();
-  if (succeeded(foldInsertAfterInsertSlice(*this)))
-    return getResult();
-  return OpFoldResult();
+/// Same folding logic for InsertSliceOp and ParallelInsertSliceOp; the return
+/// type differs, so we wrap it in a FailureOr.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
+  if (insertOp.getSourceType().hasStaticShape() &&
+      insertOp.getDestType().hasStaticShape() &&
+      insertOp.getSourceType() == insertOp.getDestType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
+          insertOp, insertOp.getDestType())))
+    return static_cast<OpFoldResult>(insertOp.getSource());
+  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
+    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
+    // return OpFoldResult().
+    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
+      return static_cast<OpFoldResult>(insertOp->getResult(0));
+    else
+      return OpFoldResult();
+  }
+  return failure();
+}
+
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
+  auto maybeOpFoldResult = foldInsertOp(*this, operands);
+  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
@@ -1562,12 +1585,15 @@ LogicalResult InsertSliceOp::reifyResultShapes(
 
 namespace {
 /// Pattern to rewrite an insert_slice op with constant arguments.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
 class InsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<InsertSliceOp> {
+    : public OpRewritePattern<InsertOpTy> {
 public:
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     // No constant operand, just return.
     if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
@@ -1587,13 +1613,20 @@ class InsertSliceOpConstantArgumentFolder final
 
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
-        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
     Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType())
+    if (sourceType != insertSliceOp.getSourceType()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+      // that the insertion point is just before the ParallelCombiningOp in the
+      // parallel case.
+      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
       toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                  sourceType, toInsert);
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    }
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
         mixedSizes, mixedStrides);
     return success();
@@ -1618,10 +1651,13 @@ class InsertSliceOpConstantArgumentFolder final
 /// Note: When folding a cast on the destination tensor, the result of the
 /// insert_slice operation is casted to ensure that the type of the result did
 /// not change.
-struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
@@ -1643,24 +1679,27 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
     auto src =
         (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
     auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-
-    auto srcType = src.getType().cast<ShapedType>();
-    auto dstType = dst.getType().cast<ShapedType>();
+    auto srcType = src.getType().template cast<ShapedType>();
+    auto dstType = dst.getType().template cast<ShapedType>();
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             insertSliceOp.getStaticSizes(),
                             insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
 
-    Value replacement = rewriter.create<InsertSliceOp>(
+    Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
 
-    if (replacement.getType() != insertSliceOp.getType()) {
-      replacement = rewriter.create<tensor::CastOp>(
-          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
+    // In the parallel case there is no result and so nothing to cast.
+    bool isParallelInsert =
+        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
+    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
+      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                    insertSliceOp.getDestType(),
+                                                    replacement->getResult(0));
     }
-    rewriter.replaceOp(insertSliceOp, replacement);
+    rewriter.replaceOp(insertSliceOp, replacement->getResults());
     return success();
   }
 };
@@ -1684,14 +1723,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
 ///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
 ///       : tensor<64x64xf32> into ...
 /// ```
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
 struct InsertSliceOpSourceCastInserter final
-    : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+    : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType srcType = insertSliceOp.getSourceType();
-    if (srcType.getRank() != insertSliceOp.getType().getRank())
+    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
       return failure();
     SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                      srcType.getShape().end());
@@ -1713,12 +1755,19 @@ struct InsertSliceOpSourceCastInserter final
     //   2) "More static" than srcType.
     //   3) Cast-compatible with srcType.
     // Insert the cast.
+    OpBuilder::InsertionGuard g(rewriter);
+    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+    // that the insertion point is just before the ParallelCombiningOp in the
+    // parallel case.
+    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
         insertSliceOp.getMixedStrides());
     return success();
   }
 };
@@ -1726,8 +1775,9 @@ struct InsertSliceOpSourceCastInserter final
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
-              InsertSliceOpSourceCastInserter>(context);
+  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
+              InsertSliceOpCastFolder<InsertSliceOp>,
+              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
 }
 
 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
@@ -2234,7 +2284,12 @@ LogicalResult ParallelInsertSliceOp::verify() {
   if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
     return this->emitError("expected ParallelCombiningOpInterface parent, got:")
            << *(getOperation()->getParentOp());
-  return success();
+
+  ShapedType expectedType;
+  SliceVerificationResult result =
+      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
+                          getStaticSizes(), getStaticStrides(), &expectedType);
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 
 namespace {
@@ -2263,51 +2318,37 @@ class ParallelInsertSliceOpConstantArgumentFolder final
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
+    auto sourceType =
+        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+            insertSliceOp.getSourceType().getRank(),
+            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
+            mixedStrides);
+    Value toInsert = insertSliceOp.getSource();
+    if (sourceType != insertSliceOp.getSourceType()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                 sourceType, toInsert);
+    }
     rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
+        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
+        mixedSizes, mixedStrides);
     return success();
   }
 };
 } // namespace
 
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-///        tensor<?xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
-///
-/// is folded into:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-///   %1 = compute_some_tensor() : tensor<64xf32>
-///   scf.foreach_thread.perform_concurrently {
-///     scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-///        tensor<64xf32> into tensor<128xf32>
-///   }
-/// }
-/// ```
 LogicalResult
 ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
                             SmallVectorImpl<OpFoldResult> &results) {
-  auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
-  if (!sourceCast)
-    return failure();
-  getSourceMutable().assign(sourceCast.getSource());
-  return success();
+  return foldInsertOp(*this, operands);
 }
 
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
+              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
+              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6196f4d8205aa..d07f3e894e242 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1429,23 +1429,25 @@ func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : inde
 // -----
 
 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
-//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<?x?xf32>, 
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
 //  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
 //  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
 func.func @canonicalize_parallel_insert_slice_indices(
-    %arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
     %num_threads : index) -> tensor<?x?xf32>
 {
   %cst = arith.constant 4.200000e+01 : f32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
 
+  //  CHECK-NOT: tensor.cast
   //      CHECK: scf.foreach_thread (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) -> (tensor<?x?xf32>) {
   // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
   // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][%[[tidx]], 0] [1, 5] [1, 1]
   %2 = scf.foreach_thread (%tidx) in (%num_threads)  -> (tensor<?x?xf32>) {
+    %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
     scf.foreach_thread.perform_concurrently {
-      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+      tensor.parallel_insert_slice %3 into %arg1[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
     }
   }
   return %2 : tensor<?x?xf32>
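
For reference (not part of the commit mail), the CHECK lines above correspond
roughly to the following canonicalized IR. This reconstruction is inferred from
the CHECK directives; the exact type annotation on the parallel_insert_slice
line is an assumption:

```mlir
func.func @canonicalize_parallel_insert_slice_indices(
    %arg0 : tensor<1x5xf32>, %arg1 : tensor<?x?xf32>,
    %num_threads : index) -> tensor<?x?xf32> {
  // The tensor.cast and the dynamic offset/size/stride operands have been
  // folded away; the slice parameters are now fully static.
  %0 = scf.foreach_thread (%tidx) in (%num_threads) -> (tensor<?x?xf32>) {
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %arg0 into %arg1[%tidx, 0] [1, 5] [1, 1]
        : tensor<1x5xf32> into tensor<?x?xf32>
    }
  }
  return %0 : tensor<?x?xf32>
}
```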

