[Mlir-commits] [mlir] [mlir][tensor] Fold rank increasing expand_shape into insert_slice (PR #93018)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu May 23 05:20:25 PDT 2024
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/93018
>From 90fbffb0374642f146b6cb377f2647f4ffda1521 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 17 May 2024 16:38:55 +0200
Subject: [PATCH 1/4] [mlir][tensor] Fold rank increasing expand_shape into
 insert_slice
---
.../Tensor/Transforms/ReshapePatterns.cpp | 38 ++++++++++++--
.../Tensor/fold-reassociative-reshapes.mlir | 49 +++++++++++++++++++
2 files changed, 83 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index d40e5f33d2a73..824bae63f14c6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -79,12 +79,42 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Fold rank increasing expand_shape into insert_slice.
+template <typename OpTy>
+struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp = insertSliceOp.getSource()
+ .template getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp)
+ return failure();
+
+ // Only fold away simple rank increasing expansion.
+ SliceVerificationResult res = isRankReducedType(
+ expandShapeOp.getResultType(), expandShapeOp.getSrcType());
+ if (res != SliceVerificationResult::Success) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "expected rank increasing expansion");
+ }
+
+ rewriter.modifyOpInPlace(insertSliceOp, [&]() {
+ insertSliceOp.setOperand(/*source=*/0, expandShapeOp.getSrc());
+ });
+ return success();
+ }
+};
} // namespace
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldExpandOfRankReducingExtract,
- FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
- FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
- patterns.getContext());
+ patterns
+ .add<FoldExpandOfRankReducingExtract,
+ FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+ FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+ FoldRankIncreasingExpandIntoInsert<tensor::InsertSliceOp>,
+ FoldRankIncreasingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index d3ac6ce792f36..9e9c66f2d3123 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -54,3 +54,52 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
}
return %1 : tensor<?x?x?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @rank_increasing_insert_of_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[insert]]
+func.func @rank_increasing_insert_of_expand_shape(
+ %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+ -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+ %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
+ : tensor<?x?xf32> into tensor<1x?x?xf32>
+ %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
+ : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+ return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @rank_increasing_parallel_insert_of_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @rank_increasing_parallel_insert_of_expand_shape(
+ %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+ -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+ %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
+ : tensor<?x?xf32> into tensor<1x?x?xf32>
+ %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
+ : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+ }
+ }
+ return %1 : tensor<?x?x?x?xf32>
+}
>From af5fecea53423dd81d45c376f1ec3a89427716f7 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 May 2024 12:40:19 +0200
Subject: [PATCH 2/4] Expand test cases
---
.../Tensor/fold-reassociative-reshapes.mlir | 81 +++++++++++++++----
1 file changed, 67 insertions(+), 14 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 9e9c66f2d3123..d6d10df02c36c 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -57,48 +57,101 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
// -----
-// CHECK-LABEL: func @rank_increasing_insert_of_expand_shape(
+// CHECK-LABEL: func @insert_of_rank_increasing_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[insert]]
-func.func @rank_increasing_insert_of_expand_shape(
+func.func @insert_of_rank_increasing_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
%sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
- %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
- : tensor<?x?xf32> into tensor<1x?x?xf32>
- %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
- : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+ : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+ %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+ : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
// -----
-// CHECK-LABEL: func @rank_increasing_parallel_insert_of_expand_shape(
+// CHECK-LABEL: func @insert_of_non_rank_increasing_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
-// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
-func.func @rank_increasing_parallel_insert_of_expand_shape(
+// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[insert]]
+func.func @insert_of_non_rank_increasing_expand_shape(
+ %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+ -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+ %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+ : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_rank_increasing_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_rank_increasing_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
%sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
- %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
- : tensor<?x?xf32> into tensor<1x?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+ : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+ %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+ : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+ }
+ }
+ return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_non_rank_increasing_expand_shape(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+// CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_non_rank_increasing_expand_shape(
+ %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+ -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+ %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+ %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
%1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
scf.forall.in_parallel {
- tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
- : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+ tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+ : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
}
}
return %1 : tensor<?x?x?x?xf32>
>From a94583b6a35abcdde886aea51323e517661f7e79 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 May 2024 14:15:17 +0200
Subject: [PATCH 3/4] Address comments
---
.../lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 824bae63f14c6..8dc1825d486ad 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -80,7 +80,8 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
}
};
-/// Fold rank increasing expand_shape into insert_slice.
+/// Fold expand_shape which only adds static dimensions of size `1`
+/// into insert_slice.
template <typename OpTy>
struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -92,16 +93,16 @@ struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
if (!expandShapeOp)
return failure();
- // Only fold away simple rank increasing expansion.
+ // Only fold away simple rank increasing expansion where all added
+ // dimensions have static size `1`.
SliceVerificationResult res = isRankReducedType(
expandShapeOp.getResultType(), expandShapeOp.getSrcType());
- if (res != SliceVerificationResult::Success) {
+ if (res != SliceVerificationResult::Success)
return rewriter.notifyMatchFailure(insertSliceOp,
"expected rank increasing expansion");
- }
rewriter.modifyOpInPlace(insertSliceOp, [&]() {
- insertSliceOp.setOperand(/*source=*/0, expandShapeOp.getSrc());
+ insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
});
return success();
}
>From c7dddc459a93cc75a8c4e8a7a6b1382789b0ef7f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 May 2024 14:19:55 +0200
Subject: [PATCH 4/4] Improve naming
---
.../Tensor/Transforms/ReshapePatterns.cpp | 19 +++++++++----------
.../Tensor/fold-reassociative-reshapes.mlir | 16 ++++++++--------
2 files changed, 17 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 8dc1825d486ad..6cf0f845f59db 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -83,7 +83,7 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
/// Fold expand_shape which only adds static dimensions of size `1`
/// into insert_slice.
template <typename OpTy>
-struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
+struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy insertSliceOp,
@@ -93,8 +93,8 @@ struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
if (!expandShapeOp)
return failure();
- // Only fold away simple rank increasing expansion where all added
- // dimensions have static size `1`.
+ // Only fold away simple expansion where all added dimensions have static
+ // size `1`.
SliceVerificationResult res = isRankReducedType(
expandShapeOp.getResultType(), expandShapeOp.getSrcType());
if (res != SliceVerificationResult::Success)
@@ -111,11 +111,10 @@ struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
- patterns
- .add<FoldExpandOfRankReducingExtract,
- FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
- FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
- FoldRankIncreasingExpandIntoInsert<tensor::InsertSliceOp>,
- FoldRankIncreasingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
- patterns.getContext());
+ patterns.add<FoldExpandOfRankReducingExtract,
+ FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+ FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+ FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
+ FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index d6d10df02c36c..644d9a918f6ca 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -57,14 +57,14 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
// -----
-// CHECK-LABEL: func @insert_of_rank_increasing_expand_shape(
+// CHECK-LABEL: func @insert_of_padding_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[insert]]
-func.func @insert_of_rank_increasing_expand_shape(
+func.func @insert_of_padding_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
@@ -80,7 +80,7 @@ func.func @insert_of_rank_increasing_expand_shape(
// -----
-// CHECK-LABEL: func @insert_of_non_rank_increasing_expand_shape(
+// CHECK-LABEL: func @insert_of_non_padding_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
@@ -89,7 +89,7 @@ func.func @insert_of_rank_increasing_expand_shape(
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[insert]]
-func.func @insert_of_non_rank_increasing_expand_shape(
+func.func @insert_of_non_padding_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
@@ -105,13 +105,13 @@ func.func @insert_of_non_rank_increasing_expand_shape(
// -----
-// CHECK-LABEL: func @parallel_insert_of_rank_increasing_expand_shape(
+// CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
-func.func @parallel_insert_of_rank_increasing_expand_shape(
+func.func @parallel_insert_of_padding_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
@@ -131,7 +131,7 @@ func.func @parallel_insert_of_rank_increasing_expand_shape(
// -----
-// CHECK-LABEL: func @parallel_insert_of_non_rank_increasing_expand_shape(
+// CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
@@ -139,7 +139,7 @@ func.func @parallel_insert_of_rank_increasing_expand_shape(
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
-func.func @parallel_insert_of_non_rank_increasing_expand_shape(
+func.func @parallel_insert_of_non_padding_expand_shape(
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
-> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits mailing list