[Mlir-commits] [mlir] e34b883 - [MLIR][Shape] Lower `shape_of` for unranked tensors
Frederik Gossen
llvmlistbot at llvm.org
Thu Jun 25 01:51:57 PDT 2020
Author: Frederik Gossen
Date: 2020-06-25T08:50:45Z
New Revision: e34b88309e7dfd86d298d01b7ab36e5858f1b62f
URL: https://github.com/llvm/llvm-project/commit/e34b88309e7dfd86d298d01b7ab36e5858f1b62f
DIFF: https://github.com/llvm/llvm-project/commit/e34b88309e7dfd86d298d01b7ab36e5858f1b62f.diff
LOG: [MLIR][Shape] Lower `shape_of` for unranked tensors
Lower `shape_of` for unranked tensors.
The shape extents are materialized in stack-allocated memory.
Differential Revision: https://reviews.llvm.org/D82196
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d61c8af99a68..8440b9b3d60b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1408,7 +1408,9 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
- "Value memrefOrTensor, int64_t index">
+ "Value memrefOrTensor, int64_t index">,
+ OpBuilder<"OpBuilder &builder, OperationState &result, "
+ "Value memrefOrTensor, Value index">
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index db7796d5c6a0..adf046ef5678 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -69,6 +69,58 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
return success();
}
+namespace {
+/// Converts `shape_of` for unranked tensors to an `scf.for` loop.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+ using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ ShapeOfOp::Adaptor transformed(operands);
+ auto tensorVal = transformed.arg();
+ auto tensorTy = tensorVal.getType();
+
+ // For ranked tensors `shape_of` lowers to `std` and the pattern can be
+ // found in the corresponding pass.
+ if (tensorTy.isa<RankedTensorType>())
+ return failure();
+
+ // Allocate stack memory.
+ auto loc = op.getLoc();
+ auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+ auto i64Ty = rewriter.getI64Type();
+ auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+ auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+ // Copy shape extents to stack-allocated memory.
+ auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+ auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+ rewriter.create<scf::ForOp>(
+ loc, zeroVal, rankVal, oneVal, ValueRange(),
+ [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
+ auto dimVal = b.create<DimOp>(loc, tensorVal, iVal);
+ auto dimIntVal = b.create<IndexCastOp>(loc, dimVal, i64Ty);
+ b.create<StoreOp>(loc, dimIntVal, memVal, ValueRange({iVal}));
+ b.create<scf::YieldOp>(loc);
+ });
+
+ // Load extents to tensor value.
+ auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
+ auto indexTy = rewriter.getIndexType();
+ auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+ rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
+ shapeTy);
+ return success();
+}
+
namespace {
struct ConvertShapeToSCFPass
: public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
@@ -79,19 +131,23 @@ struct ConvertShapeToSCFPass
void ConvertShapeToSCFPass::runOnFunction() {
MLIRContext &ctx = getContext();
+ // Populate conversion patterns.
OwningRewritePatternList patterns;
populateShapeToSCFConversionPatterns(patterns, &ctx);
+ // Setup target legality.
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
- target.addIllegalOp<ReduceOp>();
- if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+ target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+ // Apply conversion.
+ if (failed(applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
void mlir::populateShapeToSCFConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<ReduceOpConverter>(ctx);
+ patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
}
std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca4fe836495b..6e6ad47d3141 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1273,8 +1273,13 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
Value memrefOrTensor, int64_t index) {
auto loc = result.location;
Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+ build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+ Value memrefOrTensor, Value index) {
auto indexTy = builder.getIndexType();
- build(builder, result, indexTy, memrefOrTensor, indexValue);
+ build(builder, result, indexTy, memrefOrTensor, index);
}
Optional<int64_t> DimOp::getConstantIndex() {
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index b52266ce82ec..1c214567c63a 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -26,3 +26,25 @@ func @shape_reduce(%shape : !shape.shape) -> !shape.size {
// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
// CHECK-NEXT: }
// CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+ // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+ // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+ // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+ // CHECK-DAG: %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+ // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+ // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+ // CHECK: }
+ // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+ // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+ %shape = shape.shape_of %arg : tensor<*xf32>
+ return
+}
+
More information about the Mlir-commits
mailing list