[Mlir-commits] [mlir] 8544523 - [mlir][tensor] Promote extract(from_elements(...)) to folding pattern
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 20 07:48:12 PDT 2022
Author: Matthias Springer
Date: 2022-04-20T23:47:42+09:00
New Revision: 8544523dcb6249bf3055c3a6ab0cb48586999a30
URL: https://github.com/llvm/llvm-project/commit/8544523dcb6249bf3055c3a6ab0cb48586999a30
DIFF: https://github.com/llvm/llvm-project/commit/8544523dcb6249bf3055c3a6ab0cb48586999a30.diff
LOG: [mlir][tensor] Promote extract(from_elements(...)) to folding pattern
Differential Revision: https://reviews.llvm.org/D123617
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f1181e63ec8f2..1f9f97769646b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -361,16 +361,13 @@ LogicalResult ExtractOp::verify() {
}
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
- // The tensor operand must be a known constant.
- Attribute tensor = operands.front();
- if (!tensor)
- return {};
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
- if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
- return splatTensor.getSplatValue<Attribute>();
+ if (Attribute tensor = operands.front())
+ if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
+ return splatTensor.getSplatValue<Attribute>();
- // Otherwise, collect the constant indices into the tensor.
+ // Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : llvm::drop_begin(operands, 1)) {
if (!indice || !indice.isa<IntegerAttr>())
@@ -378,10 +375,34 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
indices.push_back(indice.cast<IntegerAttr>().getInt());
}
+ // Fold extract(from_elements(...)).
+ if (auto fromElementsOp = tensor().getDefiningOp<FromElementsOp>()) {
+ auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto rank = tensorType.getRank();
+ assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
+ "rank mismatch");
+ int flatIndex = 0;
+ int stride = 1;
+ for (int i = rank - 1; i >= 0; --i) {
+ if (i < rank - 1)
+ stride *= tensorType.getDimSize(i);
+ flatIndex += indices[i] * stride;
+ }
+ // Prevent out of bounds accesses. This can happen in invalid code that will
+ // never execute.
+ if (static_cast<int>(fromElementsOp.elements().size()) <= flatIndex ||
+ flatIndex < 0)
+ return {};
+ return fromElementsOp.elements()[flatIndex];
+ }
+
// If this is an elements attribute, query the value at the given indices.
- auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
- if (elementsAttr && elementsAttr.isValidIndex(indices))
- return elementsAttr.getValues<Attribute>()[indices];
+ if (Attribute tensor = operands.front()) {
+ auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
+ if (elementsAttr && elementsAttr.isValidIndex(indices))
+ return elementsAttr.getValues<Attribute>()[indices];
+ }
+
return {};
}
@@ -411,47 +432,6 @@ OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
namespace {
-// Canonicalizes the pattern of the form
-//
-// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
-//
-// to just %element.
-struct ExtractElementFromTensorFromElements
- : public OpRewritePattern<tensor::ExtractOp> {
- using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExtractOp extract,
- PatternRewriter &rewriter) const final {
- auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
- if (!tensorFromElements)
- return failure();
- auto tensorType = tensorFromElements.getType().cast<RankedTensorType>();
- auto rank = tensorType.getRank();
- if (rank == 0) {
- rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
- return success();
- }
- SmallVector<APInt, 3> indices(rank);
- int64_t flatIndex = 0;
- int64_t stride = 1;
- for (int i = rank - 1; i >= 0; --i) {
- APInt index;
- if (!matchPattern(extract.indices()[i], m_ConstantInt(&index)))
- return failure();
- if (i < rank - 1)
- stride *= tensorType.getDimSize(i);
- flatIndex += index.getSExtValue() * stride;
- }
- // Prevent out of bounds accesses. This can happen in invalid code that will
- // never execute.
- if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0)
- return failure();
- rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex));
- return success();
- }
-};
-
// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
@@ -494,9 +474,7 @@ struct ExtractElementFromIndexCast
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ExtractElementFromIndexCast, ExtractElementFromTensorFromElements>(
- context);
+ results.add<ExtractElementFromIndexCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index ced1ca525546c..a3bd3d0d26796 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -22,12 +22,8 @@ func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
// -----
More information about the Mlir-commits mailing list