[Mlir-commits] [mlir] 41849a9 - [mlir][Linalg] Avoid changing the rank of the result in canonicalizations of subtensor.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 28 11:33:34 PDT 2021


Author: MaheshRavishankar
Date: 2021-04-28T11:33:26-07:00
New Revision: 41849a91956755e15591240816d1d0c5ec402895

URL: https://github.com/llvm/llvm-project/commit/41849a91956755e15591240816d1d0c5ec402895
DIFF: https://github.com/llvm/llvm-project/commit/41849a91956755e15591240816d1d0c5ec402895.diff

LOG: [mlir][Linalg] Avoid changing the rank of the result in canonicalizations of subtensor.

Canonicalizations for subtensor operations defaulted to using the
rank-reduced version of the operation, but the cast inserted to get
back the original type would be illegal if the rank was actually
reduced. Instead, make the canonicalization not change the rank of the
result of the operation.

Differential Revision: https://reviews.llvm.org/D101258

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index 79d8f5569f5a4..c9a2f5ae13d5e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -35,7 +35,7 @@ void getPositionsOfShapeOne(unsigned rank, ArrayRef<int64_t> shape,
                             llvm::SmallDenseSet<unsigned> &dimsToProject);
 
 /// Pattern to rewrite a subview op with constant arguments.
-template <typename OpType, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -59,8 +59,12 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    auto newOp = rewriter.create<OpType>(op.getLoc(), op.source(), mixedOffsets,
-                                         mixedSizes, mixedStrides);
+    ResultTypeFunc resultTypeFunc;
+    auto resultType =
+        resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+    auto newOp =
+        rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
+                                mixedOffsets, mixedSizes, mixedStrides);
     CastOpFunc func;
     func(rewriter, op, newOp);
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1ac00022e2327..57c1b1581a232 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1859,6 +1859,26 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
   return res;
 }
 
+/// Infer the canonical type of the result of a subview operation. Returns the
+/// rank-reduced type if its rank equals `resultRank`; otherwise returns the
+/// non-rank-reduced type.
+static MemRefType
+getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubViewOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<MemRefType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+                                            mixedSizes, mixedStrides)
+                     .cast<MemRefType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref.cast past its consuming subview when
@@ -1898,7 +1918,7 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    auto resultType = SubViewOp::inferRankReducedResultType(
+    auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType().getRank(),
         castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
@@ -1914,6 +1934,17 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
 };
 } // namespace
 
+/// Return the canonical type of the result of a subview.
+struct SubViewReturnTypeCanonicalizer {
+  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubViewResultType(op.getType().getRank(),
+                                         op.getSourceType(), mixedOffsets,
+                                         mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubViewOps.
 struct SubViewCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
@@ -1923,9 +1954,10 @@ struct SubViewCanonicalizer {
 
 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubViewOp, SubViewCanonicalizer>,
-              SubViewOpMemRefCastFolder>(context);
+  results
+      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
+           SubViewOpMemRefCastFolder>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index bd1fef099e668..ae76966ba25d6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -45,10 +45,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
   // the subview op with load even if the offsets have been canonicalized
   // away.
   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
+  if (opRanges.size() != indices.size()) {
+    // For the rank-reduced cases, we can only handle the folding when the
+    // offset is zero, size is 1 and stride is 1.
+    return failure();
+  }
   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
-  assert(opRanges.size() == indices.size() &&
-         "expected as many indices as rank of subview op result type");
 
   // New indices for the load are the current indices * subview_stride +
   // subview_offset.

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b538cbabd80fb..9260e2757387c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1917,6 +1917,25 @@ static LogicalResult verify(SubTensorOp op) {
   return produceSubTensorErrorMsg(result, op, expectedType);
 }
 
+/// Infer the canonical type of the result of a subtensor operation. Returns
+/// the rank-reduced type if its rank equals `resultRank`; otherwise returns the
+/// non-rank-reduced type.
+static RankedTensorType getCanonicalSubTensorResultType(
+    unsigned resultRank, RankedTensorType sourceType,
+    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+    ArrayRef<OpFoldResult> mixedStrides) {
+  auto resultType =
+      SubTensorOp::inferRankReducedResultType(
+          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
+          .cast<RankedTensorType>();
+  if (resultType.getRank() != resultRank) {
+    resultType = SubTensorOp::inferResultType(sourceType, mixedOffsets,
+                                              mixedSizes, mixedStrides)
+                     .cast<RankedTensorType>();
+  }
+  return resultType;
+}
+
 namespace {
 /// Pattern to rewrite a subtensor op with tensor::Cast arguments.
 /// This essentially pushes memref_cast past its consuming subtensor when
@@ -1951,13 +1970,9 @@ class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
-    /// Deduce the resultType of SubTensorOp with `inferRankReducedResultType`
-    /// on the cast source operand type and the SubTensorOp static information.
-    /// This is the resulting type if the tensor::CastOp were folded and
-    /// rank-reduced to the desired result rank.
-    auto resultType = SubTensorOp::inferRankReducedResultType(
-        subTensorOp.getType().getRank(),
-        castOp.source().getType().cast<RankedTensorType>(),
+    /// Deduce the type of the result to use for the canonicalized operation.
+    RankedTensorType resultType = getCanonicalSubTensorResultType(
+        subTensorOp.getType().getRank(), subTensorOp.getSourceType(),
         subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
         subTensorOp.getMixedStrides());
     Value newSubTensor = rewriter.create<SubTensorOp>(
@@ -1972,6 +1987,18 @@ class SubTensorOpCastFolder final : public OpRewritePattern<SubTensorOp> {
 };
 } // namespace
 
+/// Return the canonical type of the result of a subtensor.
+struct SubTensorReturnTypeCanonicalizer {
+  RankedTensorType operator()(SubTensorOp op,
+                              ArrayRef<OpFoldResult> mixedOffsets,
+                              ArrayRef<OpFoldResult> mixedSizes,
+                              ArrayRef<OpFoldResult> mixedStrides) {
+    return getCanonicalSubTensorResultType(op.getType().getRank(),
+                                           op.getSourceType(), mixedOffsets,
+                                           mixedSizes, mixedStrides);
+  }
+};
+
 /// A canonicalizer wrapper to replace SubTensorOps.
 struct SubTensorCanonicalizer {
   void operator()(PatternRewriter &rewriter, SubTensorOp op,
@@ -1987,7 +2014,8 @@ struct SubTensorCanonicalizer {
 void SubTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
-                  SubTensorOp, SubTensorCanonicalizer>,
+                  SubTensorOp, SubTensorReturnTypeCanonicalizer,
+                  SubTensorCanonicalizer>,
               SubTensorOpCastFolder>(context);
 }
 
@@ -2093,22 +2121,9 @@ class SubTensorInsertOpConstantArgumentFolder final
     canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
 
     // Create the new op in canonical form.
-    Value source = subTensorInsertOp.source();
-    RankedTensorType sourceType = source.getType().cast<RankedTensorType>();
-    SmallVector<int64_t, 4> shape = llvm::to_vector<4>(
-        llvm::map_range(mixedSizes, [](OpFoldResult valueOrAttr) -> int64_t {
-          if (auto attr = valueOrAttr.dyn_cast<Attribute>())
-            return attr.cast<IntegerAttr>().getInt();
-          return ShapedType::kDynamicSize;
-        }));
-    RankedTensorType newSourceType =
-        RankedTensorType::get(shape, sourceType.getElementType());
-    Location loc = subTensorInsertOp.getLoc();
-    if (sourceType != newSourceType)
-      source = rewriter.create<tensor::CastOp>(loc, newSourceType, source);
     rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
-        subTensorInsertOp, source, subTensorInsertOp.dest(), mixedOffsets,
-        mixedSizes, mixedStrides);
+        subTensorInsertOp, subTensorInsertOp.source(), subTensorInsertOp.dest(),
+        mixedOffsets, mixedSizes, mixedStrides);
     return success();
   }
 };
@@ -2213,7 +2228,6 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
                    SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
                    SmallVectorImpl<Type> &caseOperandTypes,
                    DenseIntElementsAttr &caseOperandOffsets) {
-
   if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
       failed(parser.parseSuccessor(defaultDestination)))
     return failure();
@@ -2457,7 +2471,6 @@ static LogicalResult simplifyConstSwitchValue(SwitchOp op,
 /// ]
 static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
                                                PatternRewriter &rewriter) {
-
   SmallVector<Block *> newCaseDests;
   SmallVector<ValueRange> newCaseOperands;
   SmallVector<SmallVector<Value>> argStorage;

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7ff5c6f3fe4d5..0b0308f107cc7 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,3 +62,70 @@ func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, st
   %1 = memref.buffer_cast %0 : memref<?xf32, offset: ?, strides: [1]>
   return %1 : memref<?xf32, offset: ?, strides: [1]>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_of_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}
+
+// -----
+
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: memref.subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  return %0 : memref<?x?x?xf32, #map0>
+}
+// CHECK-LABEL: func @subview_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x1x?xf32
+//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> memref<?x?xf32, #map0>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, #map0>
+  return %0 : memref<?x?xf32, #map0>
+}
+// CHECK-LABEL: func @rank_reducing_subview_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : memref<?x?x?xf32> to memref<4x?xf32
+//       CHECK:   %[[RESULT:.+]] = memref.cast %[[SUBVIEW]]
+//       CHECK:   return %[[RESULT]]

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 908814bc46a14..e2b5e7b445b9b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -154,30 +154,41 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
 
 // -----
 
-// CHECK-LABEL: func @subview_of_memcast
-//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
-//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
-  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
-  %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-  %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
-    memref<?x?x16x32xi8> to
-    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
-  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
+// CHECK-LABEL: func @subtensor_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+//       CHECK:   return %[[RESULT]]
 
 // -----
 
-// CHECK-LABEL: func @subview_of_static_full_size
-// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
-// CHECK-NOT: memref.subview
-// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
-func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
-  %0 = memref.subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
-  return %0 : memref<4x6x16x32xi8>
+func @rank_reducing_subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
+// CHECK-LABEL: func @rank_reducing_subtensor_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
+//  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SUBTENSOR]]
+//       CHECK:   return %[[RESULT]]
 
 // -----
 
@@ -232,7 +243,89 @@ func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<
 
 // -----
 
-func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+func @subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @subtensor_to_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_insert_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor_insert %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @rank_reducing_subtensor_to_subtensor_insert_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+    %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = subtensor %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
+  %1 = subtensor_insert %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @rank_reducing_subtensor_to_subtensor_insert_canonicalize
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//       CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[ARG0]]
+//  CHECK-SAME:     [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:     : tensor<?x?x?xf32> to tensor<4x?xf32>
+//       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[SUBTENSOR]] into %[[ARG3]]
+//  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
+//  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -247,7 +340,7 @@ func @subtensor_canonicalize(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
   %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
   return %3 : tensor<?x?xi32>
 }
-// CHECK-LABEL: func @subtensor_canonicalize
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
 //       CHECK:   %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
 //  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
 //       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]


        


More information about the Mlir-commits mailing list