[Mlir-commits] [mlir] c897a7f - [mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements
Stephan Herhut
llvmlistbot at llvm.org
Tue Sep 15 06:38:34 PDT 2020
Author: Stephan Herhut
Date: 2020-09-15T15:38:14+02:00
New Revision: c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7
URL: https://github.com/llvm/llvm-project/commit/c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7
DIFF: https://github.com/llvm/llvm-project/commit/c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7.diff
LOG: [mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements
This adds canonicalizers for
- extracting an element from a dynamic_tensor_from_elements
- propagating constant operands to the type of dynamic_tensor_from_elements
Differential Revision: https://reviews.llvm.org/D87525
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 4d0cf76ec9d8..b0aa9b9e3c76 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1511,6 +1511,8 @@ def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
"ValueRange dynamicExtents, "
"function_ref<void(OpBuilder &, Location, ValueRange)>">,
];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c77bc12cca33..0c86c87384d3 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
@@ -1730,6 +1731,101 @@ void DynamicTensorFromElementsOp::build(
bodyBuilder(b, result.location, bodyBlock->getArguments());
}
+namespace {
+
+/// Canonicalizes dynamic_tensor_from_elements operations with constant extent
+/// operands into the equivalent operation with those extents expressed in the
+/// result type instead. A tensor_cast back to the original result type is
+/// inserted so that the resulting IR is still well-typed.
+///
+/// Only non-negative constants are folded. A negative value is never a valid
+/// static extent, and -1 in particular coincides with the
+/// RankedTensorType::kDynamicSize sentinel: folding it would keep the
+/// dimension dynamic while dropping its operand, yielding an op whose operand
+/// count no longer matches its number of dynamic dimensions.
+struct StaticDynamicTensorFromElements
+    : public OpRewritePattern<DynamicTensorFromElementsOp> {
+  using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
+                                PatternRewriter &rewriter) const final {
+    auto resultType =
+        tensorFromElements.getResult().getType().cast<RankedTensorType>();
+
+    // Nothing to fold if every extent is already static.
+    if (resultType.hasStaticShape())
+      return failure();
+
+    SmallVector<Value, 4> newOperands;
+    SmallVector<int64_t, 4> newShape;
+    // The dynamic extents appear in the same order as the dynamic dimensions
+    // of the result type; walk both in lock-step.
+    auto operandsIt = tensorFromElements.dynamicExtents().begin();
+
+    for (int64_t dim : resultType.getShape()) {
+      if (dim != RankedTensorType::kDynamicSize) {
+        newShape.push_back(dim);
+        continue;
+      }
+      Value extent = *operandsIt++;
+      APInt index;
+      // Keep the extent dynamic unless it is a constant that is a valid
+      // (non-negative) static size; see the struct comment for why -1 must
+      // never be folded.
+      if (!matchPattern(extent, m_ConstantInt(&index)) || index.isNegative()) {
+        newShape.push_back(RankedTensorType::kDynamicSize);
+        newOperands.push_back(extent);
+        continue;
+      }
+      newShape.push_back(index.getSExtValue());
+    }
+
+    // No extent could be folded: leave the op untouched.
+    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
+      return failure();
+
+    auto loc = tensorFromElements.getLoc();
+    auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, RankedTensorType::get(newShape, resultType.getElementType()),
+        newOperands);
+    // Move (not clone) the body into the new op; the old op is replaced below.
+    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
+                                newOp.body().begin());
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
+                                              newOp);
+    return success();
+  }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = dynamic_tensor_from_elements %x {
+///   ^bb0(%arg0: index):  // no predecessors
+///   <computation>
+///   yield %1 : index
+/// } : tensor<?xindex>
+/// %extracted_element = extract_element %tensor[%c0] : tensor<?xindex>
+///
+/// to just <computation> with %arg0 replaced by %c0. We only do this if
+/// wouldOpBeTriviallyDead holds for the dynamic_tensor_from_elements op, so
+/// that cloning its body at the extraction site cannot duplicate side effects
+/// (e.g. a store inside the body blocks this rewrite). The original op is not
+/// erased here; it is left for dead-code elimination once unused.
+struct ExtractElementFromDynamicTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorFromElements =
+        extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
+    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
+      return failure();
+
+    // Bind each body block argument to the corresponding extraction index,
+    // then clone every non-terminator op of the body through that mapping.
+    BlockAndValueMapping mapping;
+    Block *body = tensorFromElements.getBody();
+    mapping.map(body->getArguments(), extract.indices());
+    for (auto &op : body->without_terminator())
+      rewriter.clone(op, mapping);
+
+    auto yield = cast<YieldOp>(body->getTerminator());
+
+    // lookupOrDefault: the yielded value is either produced inside the body
+    // (and therefore remapped to its clone) or defined above the op (and used
+    // as-is).
+    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the canonicalizations associated with dynamic_tensor_from_elements:
+/// folding an extract_element whose aggregate is produced by this op into the
+/// op's body, and moving constant extents into the static result type.
+void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  // Matches ExtractElementOp roots whose aggregate is defined by this op.
+  results.insert<ExtractElementFromDynamicTensorFromElements>(context);
+  // Matches this op directly.
+  results.insert<StaticDynamicTensorFromElements>(context);
+}
+
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
@@ -1807,16 +1903,16 @@ struct ExtractElementFromTensorFromElements
if (extract.indices().size() != 1)
return failure();
- auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
+ auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
extract.aggregate().getDefiningOp());
- if (tensor_from_elements == nullptr)
+ if (tensorFromElements == nullptr)
return failure();
APInt index;
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
return failure();
rewriter.replaceOp(extract,
- tensor_from_elements.getOperand(index.getZExtValue()));
+ tensorFromElements.getOperand(index.getZExtValue()));
return success();
}
};
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 76fe82588be3..320418545893 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -986,3 +986,79 @@ func @extract_element_from_tensor_from_elements(%element : index) -> index {
// CHECK: [[ARG]] : index
return %extracted_element : index
}
+
+// -----
+
+// Extracting from a side-effect-free dynamic_tensor_from_elements folds to the
+// body's computation, with %arg0 replaced by the extraction index %idx.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
+// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
+ %0 = dynamic_tensor_from_elements %size {
+ ^bb0(%arg0: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ yield %1 : index
+ } : tensor<?xindex>
+ %1 = extract_element %0[%idx] : tensor<?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : index
+}
+
+// -----
+
+// Two-dimensional variant: both block arguments are remapped to the two
+// extraction indices and the whole body is inlined before the return.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
+ // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
+ // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
+ %0 = dynamic_tensor_from_elements %size, %size {
+ ^bb0(%arg0: index, %arg1: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ %2 = dim %tensor, %arg1 : tensor<*xf32>
+ %3 = addi %1, %2 : index
+ yield %3 : index
+ } : tensor<?x?xindex>
+ %4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %4 : index
+}
+
+// -----
+
+// Negative test: the body stores into %mem, so the op has side effects and the
+// extract_element must be kept rather than folded into the body.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
+// CHECK-SAME: %[[IDX:.*]]: index
+func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+ %size = rank %tensor : tensor<*xf32>
+ %mem = alloc(%size) : memref<?xindex>
+ // CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
+ %0 = dynamic_tensor_from_elements %size {
+ ^bb0(%arg0: index):
+ %1 = dim %tensor, %arg0 : tensor<*xf32>
+ store %1, %mem[%arg0] : memref<?xindex>
+ yield %1 : index
+ } : tensor<?xindex>
+ // CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
+ %1 = extract_element %0[%idx] : tensor<?xindex>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : index
+}
+
+// -----
+
+// The constant extent %c5 is moved into the result type (3x?x5x7x?), the
+// remaining non-constant extents stay as operands, and a tensor_cast restores
+// the original 3x?x?x7x? type for existing users.
+// CHECK-LABEL: @static_dynamic_tensor_from_elements
+// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
+func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
+ %c5 = constant 5 : index
+ // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
+ %0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ %1 = constant 32 : index
+ yield %1 : index
+ // CHECK: : tensor<3x?x5x7x?xindex>
+ } : tensor<3x?x?x7x?xindex>
+ // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+ return %0 : tensor<3x?x?x7x?xindex>
+}
+
More information about the Mlir-commits
mailing list