[Mlir-commits] [mlir] 42c195f - [mlir][Shape] Allow shape.split_at to return extent tensors and lower it to std.subtensor

Benjamin Kramer llvmlistbot at llvm.org
Mon Mar 8 07:48:30 PST 2021


Author: Benjamin Kramer
Date: 2021-03-08T16:48:05+01:00
New Revision: 42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c

URL: https://github.com/llvm/llvm-project/commit/42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c
DIFF: https://github.com/llvm/llvm-project/commit/42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c.diff

LOG: [mlir][Shape] Allow shape.split_at to return extent tensors and lower it to std.subtensor

split_at can return an error if the split index is out of bounds. If the
user knows that the index can never be out of bounds, it's safe to use
extent tensors. This has a straightforward lowering to std.subtensor.

Differential Revision: https://reviews.llvm.org/D98177

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 27e219dc3129..a176e6d87673 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -604,7 +604,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
     If `index` is negative, it is treated as indexing from the back of the
     shape. This negative-handling behavior is important when handling unranked
     shapes, where the positive index is not necessarily knowable due to a
-    dynamic number of leading dimensions.
+    dynamic number of leading dimensions. If the results are in extent tensor
+    form, out-of-bounds indices result in undefined behavior.
 
     Examples:
     - split_at([4,5,6], index=0) -> [], [4,5,6]
@@ -623,7 +624,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
                        Shape_SizeOrIndexType:$index);
-  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+  let results = (outs Shape_ShapeOrExtentTensorType:$head,
+                      Shape_ShapeOrExtentTensorType:$tail);
   let hasFolder = 1;
 }
 

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 2b5d619bf58e..49c44ad78e8f 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -590,6 +590,47 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
   return success();
 }
 
+namespace {
+/// Conversion pattern lowering `shape.split_at`. The rewrite logic lives in
+/// the out-of-line `matchAndRewrite` definition.
+struct SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+  using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Lowers `shape.split_at` on extent tensors to two `std.subtensor` ops:
+/// `head = operand[0 : index]` and `tail = operand[index : rank]`, where a
+/// negative `index` is first normalized to `index + rank`. No bounds check is
+/// emitted; the op documents out-of-bounds indices on extent tensors as
+/// undefined behavior.
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+    SplitAtOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Error conditions are not implemented, so only lower if the operand and
+  // both results are extent tensors and the split index is a plain `index`.
+  // The arithmetic below combines the index with the `index`-typed rank,
+  // which a `!shape.size` value cannot take part in.
+  if (op.index().getType().isa<SizeType>() ||
+      llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+    return failure();
+
+  // Build the adaptor from the operands supplied by the conversion driver,
+  // not from `op`, whose operands may still be the pre-conversion values.
+  SplitAtOp::Adaptor transformed(operands);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+  Value shape = transformed.operand();
+  Value zero = b.create<ConstantIndexOp>(0);
+  Value rank = b.create<DimOp>(shape, zero);
+
+  // index < 0 ? index + rank : index
+  Value originalIndex = transformed.index();
+  Value add = b.create<AddIOp>(originalIndex, rank);
+  Value indexIsNegative =
+      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+  // head = shape[0 : index], tail = shape[index : rank], both unit stride.
+  Value one = b.create<ConstantIndexOp>(1);
+  Value head = b.create<SubTensorOp>(shape, zero, index, one);
+  Value tailSize = b.create<SubIOp>(rank, index);
+  Value tail = b.create<SubTensorOp>(shape, index, tailSize, one);
+  rewriter.replaceOp(op, {head, tail});
+  return success();
+}
+
 namespace {
 class ToExtentTensorOpConversion
     : public OpConversionPattern<ToExtentTensorOp> {
@@ -660,6 +701,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       ReduceOpConverter,
       ShapeEqOpConverter,
       ShapeOfOpConversion,
+      SplitAtOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d8aec027a11e..a4a0f7ece9b4 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -592,3 +592,23 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
       : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// Lower `split_at` on an extent tensor with an `index` split point. A negative
+// index is normalized by adding the rank (`addi`, `cmpi slt`, `select`); the
+// head is then the slice [0, index) and the tail the slice [index, rank), both
+// taken with `subtensor` at unit stride.
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+  // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+  // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+  // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+  // Generic op form; the operand and both results are extent tensors, so the
+  // extent-tensor lowering applies.
+  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+  return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}


        


More information about the Mlir-commits mailing list