[Mlir-commits] [mlir] [mlir] [tensor] Add patterns to remove whole slicing of tensors (PR #107046)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 2 21:20:16 PDT 2024


https://github.com/Menooker created https://github.com/llvm/llvm-project/pull/107046

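Eliminate redundant `tensor.extract_slice` and `tensor.insert_slice` ops whose offsets are all zero, whose strides are all one, and whose slice sizes are proven to equal the full shape of the sliced tensor (the source of an `extract_slice`, or the destination of an `insert_slice`). Dynamic shapes are also supported.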

Examples of extract/insert ops that are removed:
```
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
```

>From 92d7879dc2b2b1666f0faf2bb7b5f43b411e3760 Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 3 Sep 2024 12:09:48 +0800
Subject: [PATCH 1/2] [mlir] [tensor] Add pattern to remove whole slicing of
 tensors

---
 .../Dialect/Tensor/Transforms/Transforms.h    |   4 +
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   1 +
 .../EliminateWholeSlicePatterns.cpp           |  98 +++++++++
 .../Tensor/eliminate-whole-slicing.mlir       | 194 ++++++++++++++++++
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  14 ++
 5 files changed, 311 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
 create mode 100644 mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..9b94a98bcde36a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,10 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
 void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                        const ControlFoldFn &controlFn);
 
+/// Appends patterns for eliminating whole-slice extract_slice and insert_slice.
+void populateEliminateWholeSlicingPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Transform helpers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index ce32dea09bb0b5..d5bbedd13e7acc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
   SubsetInsertionOpInterfaceImpl.cpp
+  EliminateWholeSlicePatterns.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
new file mode 100644
index 00000000000000..52ca6a9e6f65f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
@@ -0,0 +1,98 @@
+//===- EliminateWholeSlicePatterns.cpp - Patterns to remove whole slices --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Returns the SSA value that defines the extent of dynamic dimension `dim` of
+/// `tensor` when it can be read directly off the producer of `tensor`, or a
+/// null Value otherwise. Unlike tensor::getMixedSize this never creates IR,
+/// which matters because it is called while a pattern is still matching.
+Value getDynamicDimSizeWithoutCreatingOps(Value tensor, int64_t dim) {
+  Operation *producer = tensor.getDefiningOp();
+  if (auto extractOp = dyn_cast_or_null<ExtractSliceOp>(producer)) {
+    // Only rank-preserving slices map result dimensions one-to-one onto the
+    // slice sizes.
+    if (extractOp.getSourceType().getRank() != extractOp.getType().getRank() ||
+        !extractOp.isDynamicSize(dim))
+      return Value();
+    return extractOp.getDynamicSize(dim);
+  }
+  if (auto emptyOp = dyn_cast_or_null<EmptyOp>(producer)) {
+    RankedTensorType type = emptyOp.getType();
+    if (!type.isDynamicDim(dim))
+      return Value();
+    return emptyOp.getDynamicSizes()[type.getDynamicDimIndex(dim)];
+  }
+  if (auto generateOp = dyn_cast_or_null<GenerateOp>(producer)) {
+    RankedTensorType type = generateOp.getType();
+    if (!type.isDynamicDim(dim))
+      return Value();
+    return generateOp.getDynamicExtents()[type.getDynamicDimIndex(dim)];
+  }
+  return Value();
+}
+
+/// Returns true if `sliceOp` is provably an identity. `smallerTensor` is the
+/// slice-side value and `largerTensor` is the tensor being sliced; the op is an
+/// identity when the slice covers the whole of `largerTensor`:
+///   * both values have exactly the same type (shape, element type and
+///     encoding), so one can replace the other without retyping its users;
+///   * every offset is the constant 0 and every stride the constant 1, given
+///     either statically or as a constant SSA value;
+///   * every dynamic size of the slice is the very SSA value that the producer
+///     of `largerTensor` uses for the corresponding extent.
+/// This only inspects the IR and never modifies it, so a pattern may still
+/// return failure() after calling it. `rewriter` is currently unused.
+bool checkEliminateOK(PatternRewriter &rewriter,
+                      OffsetSizeAndStrideOpInterface sliceOp,
+                      mlir::TypedValue<mlir::RankedTensorType> smallerTensor,
+                      mlir::TypedValue<mlir::RankedTensorType> largerTensor) {
+  RankedTensorType largeType = largerTensor.getType();
+  RankedTensorType smallType = smallerTensor.getType();
+  // Exact type equality (not isSameTypeWithoutEncoding): the replacement value
+  // must have the type of the value it replaces. This also rejects
+  // rank-reducing and rank-expanding slices.
+  if (largeType != smallType)
+    return false;
+  for (OpFoldResult offset : sliceOp.getMixedOffsets())
+    if (!isConstantIntValue(offset, 0))
+      return false;
+  for (OpFoldResult stride : sliceOp.getMixedStrides())
+    if (!isConstantIntValue(stride, 1))
+      return false;
+  // Static extents already agree because the types are equal. Each dynamic
+  // extent must be proven equal without materializing tensor.dim ops.
+  for (int64_t i = 0, e = smallType.getRank(); i < e; ++i) {
+    if (!smallType.isDynamicDim(i))
+      continue;
+    if (!sliceOp.isDynamicSize(i))
+      return false;
+    Value largeDim = getDynamicDimSizeWithoutCreatingOps(largerTensor, i);
+    if (!largeDim || largeDim != sliceOp.getDynamicSize(i))
+      return false;
+  }
+  return true;
+}
+
+/// Replaces a tensor.extract_slice that reads the whole of its source tensor
+/// with that source tensor.
+struct EliminateWholeSliceExtractSliceOp
+    : public OpRewritePattern<ExtractSliceOp> {
+  EliminateWholeSliceExtractSliceOp(MLIRContext *context)
+      : OpRewritePattern<ExtractSliceOp>(context) {}
+
+  LogicalResult matchAndRewrite(ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto wholeTensor = extractOp.getSource();
+    // The op result is the slice; the source is the tensor being sliced.
+    bool isIdentity = checkEliminateOK(rewriter, extractOp,
+                                       extractOp.getResult(), wholeTensor);
+    if (isIdentity) {
+      rewriter.replaceOp(extractOp, wholeTensor);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Replaces a tensor.insert_slice that overwrites the whole of its destination
+/// tensor with the inserted source tensor.
+struct EliminateWholeSliceInsertSliceOp
+    : public OpRewritePattern<InsertSliceOp> {
+  EliminateWholeSliceInsertSliceOp(MLIRContext *context)
+      : OpRewritePattern<InsertSliceOp>(context) {}
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto insertedTensor = insertOp.getSource();
+    // The inserted source is the slice; the destination is the tensor that
+    // would be overwritten.
+    if (!checkEliminateOK(rewriter, insertOp, insertedTensor,
+                          insertOp.getDest()))
+      return failure();
+    rewriter.replaceOp(insertOp, insertedTensor);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateEliminateWholeSlicingPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<EliminateWholeSliceExtractSliceOp>(context);
+  patterns.add<EliminateWholeSliceInsertSliceOp>(context);
+}
diff --git a/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
new file mode 100644
index 00000000000000..077d36c26d4816
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-eliminate-whole-slicing-patterns -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+//////////////////////////////
+// Tests for tensor.insert_slice
+//////////////////////////////
+
+// The inner insert_slice writes %3 over all of %extracted_slice: zero offsets,
+// unit strides and the same dynamic size %arg3 as the producing extract_slice.
+// It is replaced by %3, so the outer insert_slice consumes the fill directly.
+func.func @elim_dyn_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_dyn_insert
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+//       CHECK:   return %[[INSERT]]
+
+// Static-shape variant: the inner insert_slice covers the whole 15x32x32x32
+// destination and is replaced by %3.
+func.func @elim_static_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_static_insert
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_dyn_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Not eliminated: the inner insert_slice size %arg2 is not provably equal to
+// the destination size %arg3.
+// CHECK-LABEL: func.func @fail_dyn_insert_shape
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_static_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %3 = tensor.empty() : tensor<14x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<14x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Not eliminated: the 14x32x32x32 source covers only part of the 15x32x32x32
+// destination.
+// CHECK-LABEL: func.func @fail_static_insert_shape
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   tensor.empty()
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_dyn_insert_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, %arg2] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// Not eliminated: the innermost stride %arg2 is dynamic, so it is not known to
+// be 1.
+// CHECK-LABEL: func.func @fail_dyn_insert_stride
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+// Not eliminated: the innermost offset is 1 instead of 0.
+// NOTE(review): offset 1 plus size 32 exceeds the destination extent of 32 in
+// that dimension, so this slice is statically out of bounds; a stricter
+// verifier may reject this input. Consider an in-bounds slice instead.
+func.func @fail_static_insert_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 1] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_insert_offset
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+//////////////////////////////
+// Tests for tensor.extract_slice
+//////////////////////////////
+// The second extract_slice reads all of %extracted_slice: zero offsets, unit
+// strides and the same dynamic size %arg3 as its producer. It is replaced by
+// %extracted_slice.
+func.func @elim_dyn_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_dyn_extract
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [%[[OFFSET1]], 32, 32, 32]
+//       CHECK:   return %[[EXTRACT]]
+
+
+// Static-shape variant: the second extract_slice reads the whole 15x32x32x32
+// tensor and is replaced by %extracted_slice.
+func.func @elim_static_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_static_extract
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [15, 32, 32, 32]
+//       CHECK:   return %[[EXTRACT]]
+
+// Not eliminated: the second slice size %arg2 is not provably equal to the
+// source size %arg3.
+func.func @fail_dyn_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_dyn_extract_shape
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// Not eliminated: the slice extracts 14 of the 15 elements of dimension 0.
+func.func @fail_static_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<14x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<14x32x32x32xbf16>
+  return %extracted_slice2 : tensor<14x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_shape
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// Not eliminated: the innermost stride is 3 instead of 1.
+// NOTE(review): with offset 0, size 32 and stride 3 the last accessed index is
+// 93, beyond the extent of 32, so this slice is statically out of bounds; a
+// stricter verifier may reject this input.
+func.func @fail_extract_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 3] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_stride
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// Not eliminated: the leading offset %arg2 is dynamic, so it is not known to
+// be 0. NOTE(review): despite the test name, the offset here is not static.
+func.func @fail_static_extract_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_offset
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+
+
+//////////////////////////////
+// Tests for rank-reducing and rank-expanding slices
+//////////////////////////////
+// Not eliminated: the second extract_slice drops the leading unit dimension,
+// so its result type differs from its source type.
+func.func @fail_extract_reduce(%arg0: tensor<1x32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x32x32x32x32xbf16> to tensor<1x15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_reduce
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// Not eliminated: the insert_slice source has rank 4 but the destination has
+// rank 5. NOTE(review): %extracted_slice here is a tensor.empty result, not an
+// extract_slice; a clearer SSA name would help.
+func.func @fail_insert_expand(%arg0: tensor<1x15x32x32x32xbf16>, %arg1: tensor<1x15x32x32x32xbf16>, %arg2: index) -> tensor<1x15x32x32x32xbf16> {
+  %extracted_slice = tensor.empty(): tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.insert_slice %extracted_slice into %arg0[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<1x15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<1x15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_insert_expand
+//       CHECK:   tensor.empty
+//       CHECK:   tensor.insert_slice
+//       CHECK:   return
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 34de600132f5de..a4a91ffd3b7660 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -98,6 +98,12 @@ struct TestTensorTransforms
       *this, "test-tracking-listener",
       llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
       llvm::cl::init(false)};
+
+  Option<bool> testEliminateWholeSlicingPatterns{
+      *this, "test-eliminate-whole-slicing-patterns",
+      llvm::cl::desc("Test patterns to eliminate whole-slicing extract_slice "
+                     "and insert_slice"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -154,6 +160,12 @@ static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+/// Runs the whole-slice elimination patterns greedily over `rootOp`,
+/// discarding the result reported by the greedy driver.
+static void applyEliminateWholeSlicingPatterns(Operation *rootOp) {
+  RewritePatternSet wholeSlicePatterns(rootOp->getContext());
+  tensor::populateEliminateWholeSlicingPatterns(wholeSlicePatterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(wholeSlicePatterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -406,6 +418,8 @@ void TestTensorTransforms::runOnOperation() {
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
       return signalPassFailure();
   }
+  if (testEliminateWholeSlicingPatterns)
+    applyEliminateWholeSlicingPatterns(rootOp);
   if (testTrackingListener)
     if (failed(testTrackingListenerReplacements(rootOp)))
       return signalPassFailure();

>From 09276d8e961e2f2c589d0d3470dcaacf647fa7dc Mon Sep 17 00:00:00 2001
From: "Mei, Yijie" <yijie.mei at intel.com>
Date: Tue, 3 Sep 2024 12:16:08 +0800
Subject: [PATCH 2/2] add doc

---
 mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 9b94a98bcde36a..4ae782661681f1 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -103,6 +103,8 @@ void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                        const ControlFoldFn &controlFn);
 
 /// Appends patterns for eliminating whole-slice extract_slice and insert_slice.
+/// An op is replaced by its source when its offsets are all zero, its strides
+/// are all one and its sizes provably cover the whole sliced tensor.
 void populateEliminateWholeSlicingPatterns(
     RewritePatternSet &patterns);
 



More information about the Mlir-commits mailing list