[Mlir-commits] [mlir] e34b883 - [MLIR][Shape] Lower `shape_of` for unranked tensors

Frederik Gossen llvmlistbot at llvm.org
Thu Jun 25 01:51:57 PDT 2020


Author: Frederik Gossen
Date: 2020-06-25T08:50:45Z
New Revision: e34b88309e7dfd86d298d01b7ab36e5858f1b62f

URL: https://github.com/llvm/llvm-project/commit/e34b88309e7dfd86d298d01b7ab36e5858f1b62f
DIFF: https://github.com/llvm/llvm-project/commit/e34b88309e7dfd86d298d01b7ab36e5858f1b62f.diff

LOG: [MLIR][Shape] Lower `shape_of` for unranked tensors

Lower `shape_of` for unranked tensors.
The shape is materialized in stack-allocated memory.

Differential Revision: https://reviews.llvm.org/D82196
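
For illustration, a rough sketch of what this lowering produces for an
unranked operand (the SSA value names below are made up; the exact output is
pinned down by the CHECK lines in the new test case further down):

  // Input:
  %shape = shape.shape_of %arg : tensor<*xf32>

  // After conversion (roughly):
  %rank = rank %arg : tensor<*xf32>
  %shape_mem = alloca(%rank) : memref<?xi64>
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.for %i = %c0 to %rank step %c1 {
    %dim = dim %arg, %i : tensor<*xf32>
    %dim_i64 = index_cast %dim : index to i64
    store %dim_i64, %shape_mem[%i] : memref<?xi64>
  }
  %shape_i64 = tensor_load %shape_mem : memref<?xi64>
  %shape = index_cast %shape_i64 : tensor<?xi64> to tensor<?xindex>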

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index d61c8af99a68..8440b9b3d60b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1408,7 +1408,9 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{

diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index db7796d5c6a0..adf046ef5678 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -69,6 +69,58 @@ ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
   return success();
 }
 
+namespace {
+/// Converts `shape_of` to an `scf.for` loop for unranked tensors.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<RankedTensorType>())
+    return failure();
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+  auto i64Ty = rewriter.getI64Type();
+  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+  // Copy shape extents to stack-allocated memory.
+  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zeroVal, rankVal, oneVal, ValueRange(),
+      [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) {
+        auto dimVal = b.create<DimOp>(loc, tensorVal, iVal);
+        auto dimIntVal = b.create<IndexCastOp>(loc, dimVal, i64Ty);
+        b.create<StoreOp>(loc, dimIntVal, memVal, ValueRange({iVal}));
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load the extents into a tensor value.
+  auto shapeIntVal = rewriter.create<TensorLoadOp>(loc, memVal);
+  auto indexTy = rewriter.getIndexType();
+  auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<IndexCastOp>(op.getOperation(), shapeIntVal,
+                                           shapeTy);
+  return success();
+}
+
 namespace {
 struct ConvertShapeToSCFPass
     : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
@@ -79,19 +131,23 @@ struct ConvertShapeToSCFPass
 void ConvertShapeToSCFPass::runOnFunction() {
   MLIRContext &ctx = getContext();
 
+  // Populate conversion patterns.
   OwningRewritePatternList patterns;
   populateShapeToSCFConversionPatterns(patterns, &ctx);
 
+  // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
-  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+  // Apply conversion.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ca4fe836495b..6e6ad47d3141 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1273,8 +1273,13 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index b52266ce82ec..1c214567c63a 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -26,3 +26,25 @@ func @shape_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
 // CHECK-NEXT: }
 // CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK-DAG:   %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK-DAG:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK-DAG:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK:     }
+  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  return
+}
+
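
To try this out, the updated test above is the easiest entry point: it is
meant to be run through mlir-opt with the shape-to-SCF conversion pass
(presumably via the -convert-shape-to-scf flag, matching the pass name
created in ShapeToSCF.cpp) and verified with FileCheck.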

More information about the Mlir-commits mailing list