[Mlir-commits] [mlir] 88df30c - [mlir] Add canonicalization for extract(tensor.from_elements) in 0d case.

Alexander Belyaev llvmlistbot at llvm.org
Thu Dec 16 06:54:42 PST 2021


Author: Alexander Belyaev
Date: 2021-12-16T15:46:57+01:00
New Revision: 88df30c8d81d3f0651fb8ad4e86b08eb2898744c

URL: https://github.com/llvm/llvm-project/commit/88df30c8d81d3f0651fb8ad4e86b08eb2898744c
DIFF: https://github.com/llvm/llvm-project/commit/88df30c8d81d3f0651fb8ad4e86b08eb2898744c.diff

LOG: [mlir] Add canonicalization for extract(tensor.from_elements) in 0d case.

Differential Revision: https://reviews.llvm.org/D115875

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 37906adb2918..49db8688c0da 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -402,6 +402,10 @@ struct ExtractElementFromTensorFromElements
       return failure();
     auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
     auto rank = tensorType.getRank();
+    if (rank == 0) {
+      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+      return success();
+    }
     SmallVector<APInt, 3> indices(rank);
     int64_t flatIndex = 0;
     int64_t stride = 1;

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 5331e50790a6..2c18fe4a6d5e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -135,6 +135,18 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
+func @extract_from_tensor.from_elements_0d(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c0 = arith.constant 0 : index
+  %tensor = tensor.from_elements %element : tensor<index>
+  %extracted_element = tensor.extract %tensor[] : tensor<index>
+  // CHECK: [[ARG]] : index
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.from_elements_3d
 func @extract_from_tensor.from_elements_3d()
     -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {


        


More information about the Mlir-commits mailing list