[Mlir-commits] [mlir] 7b52aea - [mlir][Tensor] Add folding for tensor.from_elements

Benjamin Kramer llvmlistbot at llvm.org
Mon May 10 15:42:55 PDT 2021


Author: Benjamin Kramer
Date: 2021-05-11T00:42:45+02:00
New Revision: 7b52aeadfa38c8a1fc0e97066f50900f1efafd42

URL: https://github.com/llvm/llvm-project/commit/7b52aeadfa38c8a1fc0e97066f50900f1efafd42
DIFF: https://github.com/llvm/llvm-project/commit/7b52aeadfa38c8a1fc0e97066f50900f1efafd42.diff

LOG: [mlir][Tensor] Add folding for tensor.from_elements

This trivially folds into a constant when all operands are constant.

Differential Revision: https://reviews.llvm.org/D102199

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Linalg/detensorize_trivial.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a0e473873d27a..17141da1b3e88 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -137,6 +137,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
   ];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1beb458df1f53..2c9680adbf1b4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -238,6 +238,12 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
   build(builder, result, elements.front().getType(), elements);
 }
 
+OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
+  if (llvm::is_contained(operands, nullptr)) // Some element is not constant.
+    return {};
+  return DenseElementsAttr::get(getType(), operands); // All constant: fold.
+}
+
 namespace {
 
 // Canonicalizes the pattern of the form

diff  --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
index 6fcd056f9b365..4e0b8fdd00468 100644
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -35,9 +35,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
 // DET-ALL-NEXT:  }
 
 // DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
-// DET-CF-NEXT:    constant 10 : i32
-// DET-CF-NEXT:    tensor.from_elements %{{.*}}
-// DET-CF-NEXT:    linalg.tensor_reshape %{{.*}}
+// DET-CF-NEXT:    constant dense<10> : tensor<i32>
 // DET-CF-NEXT:    linalg.init_tensor [] : tensor<i1>
 // DET-CF-NEXT:    linalg.generic
 // DET-CF-NEXT:    ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index be22f323873e1..478117b325c94 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -238,3 +238,15 @@ func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xi
   // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
   return %0 : tensor<3x?x?x7x?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: @from_elements.constant
+func @from_elements.constant() -> tensor<3xindex> {
+  // CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex>
+  // CHECK: return %[[CST]]
+  %one = constant 1 : index // Used twice below (repeated operand).
+  %two = constant 2 : index
+  %elems = tensor.from_elements %one, %two, %one : tensor<3xindex>
+  return %elems : tensor<3xindex>
+}


        


More information about the Mlir-commits mailing list