[Mlir-commits] [mlir] 7b52aea - [mlir][Tensor] Add folding for tensor.from_elements
Benjamin Kramer
llvmlistbot at llvm.org
Mon May 10 15:42:55 PDT 2021
Author: Benjamin Kramer
Date: 2021-05-11T00:42:45+02:00
New Revision: 7b52aeadfa38c8a1fc0e97066f50900f1efafd42
URL: https://github.com/llvm/llvm-project/commit/7b52aeadfa38c8a1fc0e97066f50900f1efafd42
DIFF: https://github.com/llvm/llvm-project/commit/7b52aeadfa38c8a1fc0e97066f50900f1efafd42.diff
LOG: [mlir][Tensor] Add folding for tensor.from_elements
This trivially folds into a constant when all operands are constant.
Differential Revision: https://reviews.llvm.org/D102199
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Linalg/detensorize_trivial.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a0e473873d27a..17141da1b3e88 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -137,6 +137,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1beb458df1f53..2c9680adbf1b4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -238,6 +238,12 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, elements.front().getType(), elements);
}
+/// Folds tensor.from_elements into a DenseElementsAttr of the op's result
+/// type, provided that every element operand is a constant attribute.
+OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
+  // A null entry means that element is not a known constant, so the op
+  // cannot be replaced by a constant; leave it untouched.
+  if (llvm::is_contained(operands, nullptr))
+    return {};
+  return DenseElementsAttr::get(getType(), operands);
+}
+
namespace {
// Canonicalizes the pattern of the form
diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
index 6fcd056f9b365..4e0b8fdd00468 100644
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -35,9 +35,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
// DET-ALL-NEXT: }
// DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
-// DET-CF-NEXT: constant 10 : i32
-// DET-CF-NEXT: tensor.from_elements %{{.*}}
-// DET-CF-NEXT: linalg.tensor_reshape %{{.*}}
+// DET-CF-NEXT: constant dense<10> : tensor<i32>
// DET-CF-NEXT: linalg.init_tensor [] : tensor<i1>
// DET-CF-NEXT: linalg.generic
// DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index be22f323873e1..478117b325c94 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -238,3 +238,15 @@ func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xi
// CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}
+
+// -----
+
+// tensor.from_elements whose operands are all constants is expected to fold
+// away into a single dense constant that is returned directly. %c1 is passed
+// twice, so the folded value must repeat it: [1, 2, 1].
+// CHECK-LABEL: @from_elements.constant
+func @from_elements.constant() -> tensor<3xindex> {
+ // CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex>
+ // CHECK: return %[[CST]]
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
+ return %tensor : tensor<3xindex>
+}
More information about the Mlir-commits
mailing list