[Mlir-commits] [mlir] 88df30c - [mlir] Add canonicalization for extract(tensor.from_elements) in 0d case.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Dec 16 06:54:42 PST 2021
Author: Alexander Belyaev
Date: 2021-12-16T15:46:57+01:00
New Revision: 88df30c8d81d3f0651fb8ad4e86b08eb2898744c
URL: https://github.com/llvm/llvm-project/commit/88df30c8d81d3f0651fb8ad4e86b08eb2898744c
DIFF: https://github.com/llvm/llvm-project/commit/88df30c8d81d3f0651fb8ad4e86b08eb2898744c.diff
LOG: [mlir] Add canonicalization for extract(tensor.from_elements) in 0d case.
Differential Revision: https://reviews.llvm.org/D115875
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 37906adb2918..49db8688c0da 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -402,6 +402,10 @@ struct ExtractElementFromTensorFromElements
return failure();
auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
auto rank = tensorType.getRank();
+ if (rank == 0) {
+ rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+ return success();
+ }
SmallVector<APInt, 3> indices(rank);
int64_t flatIndex = 0;
int64_t stride = 1;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 5331e50790a6..2c18fe4a6d5e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -135,6 +135,18 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
+// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
+func @extract_from_tensor.from_elements_0d(%element : index) -> index {
+ // CHECK-SAME: ([[ARG:%.*]]: index)
+ %c0 = arith.constant 0 : index
+ %tensor = tensor.from_elements %element : tensor<index>
+ %extracted_element = tensor.extract %tensor[] : tensor<index>
+ // CHECK: [[ARG]] : index
+ return %extracted_element : index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_from_tensor.from_elements_3d
func @extract_from_tensor.from_elements_3d()
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
More information about the Mlir-commits mailing list