[Mlir-commits] [mlir] [mlir] [tensor] Add patterns to remove whole slicing of tensors (PR #107046)
llvmlistbot at llvm.org
Mon Sep 2 21:20:16 PDT 2024
https://github.com/Menooker created https://github.com/llvm/llvm-project/pull/107046
Eliminate redundant `tensor.extract_slice` and `tensor.insert_slice` ops when the slice is proven to cover the whole source tensor. Dynamic shapes are also supported.
Examples of extract/insert ops that can be removed:
```
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
```
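After the rewrite, each whole-slice op is erased and all uses of its result are replaced with the source operand. As a minimal sketch (illustrative function, not part of the patch): with zero offsets, unit strides, and sizes equal to the source shape, both ops below are eliminated and the function folds to `return %arg0`:
```
func.func @whole_slice(%arg0: tensor<15x32xbf16>) -> tensor<15x32xbf16> {
  %0 = tensor.extract_slice %arg0[0, 0] [15, 32] [1, 1] : tensor<15x32xbf16> to tensor<15x32xbf16>
  %1 = tensor.insert_slice %0 into %arg0[0, 0] [15, 32] [1, 1] : tensor<15x32xbf16> into tensor<15x32xbf16>
  return %1 : tensor<15x32xbf16>
}
```
To drive the patterns from a pass, one can mirror the test hook added in `TestTensorTransforms.cpp` below (a sketch; the helper name `eliminateWholeSlices` is ours):
```
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Greedily applies the whole-slice elimination patterns until fixpoint.
static void eliminateWholeSlices(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::tensor::populateEliminateWholeSlicingPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
```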
From 92d7879dc2b2b1666f0faf2bb7b5f43b411e3760 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 3 Sep 2024 12:09:48 +0800
Subject: [PATCH 1/2] [mlir] [tensor] Add pattern to remove whole slicing of
tensors
---
.../Dialect/Tensor/Transforms/Transforms.h | 4 +
.../Dialect/Tensor/Transforms/CMakeLists.txt | 1 +
.../EliminateWholeSlicePatterns.cpp | 98 +++++++++
.../Tensor/eliminate-whole-slicing.mlir | 194 ++++++++++++++++++
.../Dialect/Tensor/TestTensorTransforms.cpp | 14 ++
5 files changed, 311 insertions(+)
create mode 100644 mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
create mode 100644 mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..9b94a98bcde36a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,10 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);
+/// Appends patterns for eliminating whole-slice extract_slice and insert_slice.
+void populateEliminateWholeSlicingPatterns(
+ RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index ce32dea09bb0b5..d5bbedd13e7acc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
RewriteAsConstant.cpp
SwapExtractSliceWithProducerPatterns.cpp
SubsetInsertionOpInterfaceImpl.cpp
+ EliminateWholeSlicePatterns.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
new file mode 100644
index 00000000000000..52ca6a9e6f65f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
@@ -0,0 +1,98 @@
+//===- EliminateWholeSlicePatterns.cpp - Patterns to remove whole slices --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
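+/// Returns true when `sliceOp` takes a whole-tensor slice, i.e. the two
+/// tensors have the same type (ignoring encoding), all offsets are static
+/// zeros, all strides are static ones, and every dynamic size of
+/// `smallerTensor` is the same SSA value as the matching dimension of
+/// `largerTensor`.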
+bool checkEliminateOK(PatternRewriter &rewriter,
+ OffsetSizeAndStrideOpInterface sliceOp,
+ mlir::TypedValue<mlir::RankedTensorType> smallerTensor,
+ mlir::TypedValue<mlir::RankedTensorType> largerTensor) {
+ auto srcType = largerTensor.getType();
+ auto resultType = smallerTensor.getType();
+ if (!isSameTypeWithoutEncoding(srcType, resultType)) {
+    // fast failure path when the input and output types do not match
+ return false;
+ }
+ // both types are ensured to have the same rank
+ for (int64_t i = 0; i < resultType.getRank(); ++i) {
+    // check the slice offsets; they should all be static zeros
+ if (sliceOp.isDynamicOffset(i) || sliceOp.getStaticOffset(i) != 0)
+ return false;
+    // check the slice strides; they should all be static ones
+ if (sliceOp.isDynamicStride(i) || sliceOp.getStaticStride(i) != 1)
+ return false;
+ }
+  // check if the dynamic shapes match
+ if (resultType.getNumDynamicDims() != 0) {
+ for (int64_t i = 0; i < resultType.getRank(); ++i) {
+ if (resultType.isDynamicDim(i)) {
+ auto largeDim =
+ getMixedSize(rewriter, sliceOp.getLoc(), largerTensor, i);
+ auto smallDim = sliceOp.getDynamicSize(i);
+ if (largeDim.dyn_cast<Value>() != smallDim) {
+ return false;
+ }
+ }
+ }
+ }
+  // If the tensor has a static shape, we have already checked that the shapes
+  // match via isSameTypeWithoutEncoding.
+ return true;
+}
+
+struct EliminateWholeSliceExtractSliceOp
+ : public OpRewritePattern<ExtractSliceOp> {
+ EliminateWholeSliceExtractSliceOp(MLIRContext *ctx)
+ : OpRewritePattern<ExtractSliceOp>(ctx) {}
+
+ LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getResult(),
+ sliceOp.getSource())) {
+ return failure();
+ }
+    // All checks passed; rewrite the IR.
+    rewriter.replaceOp(sliceOp, sliceOp.getSource());
+ return success();
+ }
+};
+
+struct EliminateWholeSliceInsertSliceOp
+ : public OpRewritePattern<InsertSliceOp> {
+ EliminateWholeSliceInsertSliceOp(MLIRContext *ctx)
+ : OpRewritePattern<InsertSliceOp>(ctx) {}
+
+ LogicalResult matchAndRewrite(InsertSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getSource(),
+ sliceOp.getDest())) {
+ return failure();
+ }
+    // All checks passed; rewrite the IR.
+    rewriter.replaceOp(sliceOp, sliceOp.getSource());
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateEliminateWholeSlicingPatterns(
+ RewritePatternSet &patterns) {
+ patterns
+ .add<EliminateWholeSliceExtractSliceOp, EliminateWholeSliceInsertSliceOp>(
+ patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
new file mode 100644
index 00000000000000..077d36c26d4816
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-eliminate-whole-slicing-patterns -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+//////////////////////////////
+// Tests for insert_slice
+//////////////////////////////
+
+func.func @elim_dyn_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_dyn_insert
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+// CHECK: return %[[INSERT]]
+
+func.func @elim_static_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_static_insert
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+// CHECK: return %[[INSERT]]
+
+func.func @fail_dyn_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Fails to optimize due to a mismatched insert size.
+// CHECK-LABEL: func.func @fail_dyn_insert_shape
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+func.func @fail_static_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %3 = tensor.empty() : tensor<14x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<14x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Fails to optimize due to a mismatched insert size.
+// CHECK-LABEL: func.func @fail_static_insert_shape
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: tensor.empty()
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+func.func @fail_dyn_insert_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, %arg2] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Fails to optimize due to a dynamic stride.
+// CHECK-LABEL: func.func @fail_dyn_insert_stride
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+// Fails to optimize due to a non-zero offset.
+func.func @fail_static_insert_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 1] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_insert_offset
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+//////////////////////////////
+// Tests for extract_slice
+//////////////////////////////
+func.func @elim_dyn_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_dyn_extract
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [%[[OFFSET1]], 32, 32, 32]
+// CHECK: return %[[EXTRACT]]
+
+
+func.func @elim_static_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_static_extract
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [15, 32, 32, 32]
+// CHECK: return %[[EXTRACT]]
+
+// Fails to optimize due to a mismatched size.
+func.func @fail_dyn_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_dyn_extract_shape
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// Fails to optimize due to a mismatched size.
+func.func @fail_static_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<14x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<14x32x32x32xbf16>
+ return %extracted_slice2 : tensor<14x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_shape
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// Fails to optimize due to a non-unit stride.
+func.func @fail_extract_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 3] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_stride
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// Fails to optimize due to a non-zero offset.
+func.func @fail_static_extract_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_offset
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+
+
+//////////////////////////////
+// Tests for rank-expanding/rank-reducing slices
+//////////////////////////////
+func.func @fail_extract_reduce(%arg0: tensor<1x32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x32x32x32x32xbf16> to tensor<1x15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_reduce
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+func.func @fail_insert_expand(%arg0: tensor<1x15x32x32x32xbf16>, %arg1: tensor<1x15x32x32x32xbf16>, %arg2: index) -> tensor<1x15x32x32x32xbf16> {
+  %extracted_slice = tensor.empty() : tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.insert_slice %extracted_slice into %arg0[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<1x15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<1x15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_insert_expand
+// CHECK: tensor.empty
+// CHECK: tensor.insert_slice
+// CHECK: return
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 34de600132f5de..a4a91ffd3b7660 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -98,6 +98,12 @@ struct TestTensorTransforms
*this, "test-tracking-listener",
llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
llvm::cl::init(false)};
+
+ Option<bool> testEliminateWholeSlicingPatterns{
+ *this, "test-eliminate-whole-slicing-patterns",
+ llvm::cl::desc("Test patterns to eliminate whole-slicing extract_slice "
+ "and insert_slice"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -154,6 +160,12 @@ static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyEliminateWholeSlicingPatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateEliminateWholeSlicingPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -406,6 +418,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return signalPassFailure();
}
+ if (testEliminateWholeSlicingPatterns)
+ applyEliminateWholeSlicingPatterns(rootOp);
if (testTrackingListener)
if (failed(testTrackingListenerReplacements(rootOp)))
return signalPassFailure();
From 09276d8e961e2f2c589d0d3470dcaacf647fa7dc Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 3 Sep 2024 12:16:08 +0800
Subject: [PATCH 2/2] add doc
---
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 9b94a98bcde36a..4ae782661681f1 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -103,6 +103,8 @@ void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);
/// Appends patterns for eliminating whole-slice extract_slice and insert_slice.
+/// The patterns remove extract_slice and insert_slice ops when the slice size
+/// matches that of the source tensor, all offsets are zero, and all strides
+/// are one.
void populateEliminateWholeSlicingPatterns(
RewritePatternSet &patterns);