[Mlir-commits] [mlir] 42c195f - [mlir][Shape] Allow shape.split_at to return extent tensors and lower it to std.subtensor
Benjamin Kramer
llvmlistbot at llvm.org
Mon Mar 8 07:48:30 PST 2021
Author: Benjamin Kramer
Date: 2021-03-08T16:48:05+01:00
New Revision: 42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c
URL: https://github.com/llvm/llvm-project/commit/42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c
DIFF: https://github.com/llvm/llvm-project/commit/42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c.diff
LOG: [mlir][Shape] Allow shape.split_at to return extent tensors and lower it to std.subtensor
split_at can return an error if the split index is out of bounds. If the
user knows that the index can never be out of bounds, it's safe to use
extent tensors. This has a straightforward lowering to std.subtensor.
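For illustration, a minimal sketch of the extent tensor form this enables,
mirroring the test added below (generic op syntax; names are illustrative):

  func @split(%shape : tensor<?xindex>, %index : index)
      -> (tensor<?xindex>, tensor<?xindex>) {
    // Both results are extent tensors; %index must be in bounds,
    // otherwise the behavior is undefined.
    %head, %tail = "shape.split_at"(%shape, %index)
        : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
    return %head, %tail : tensor<?xindex>, tensor<?xindex>
  }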
Differential Revision: https://reviews.llvm.org/D98177
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 27e219dc3129..a176e6d87673 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -604,7 +604,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
If `index` is negative, it is treated as indexing from the back of the
shape. This negative-handling behavior is important when handling unranked
shapes, where the positive index is not necessarily knowable due to a
- dynamic number of leading dimensions.
+ dynamic number of leading dimensions. If the result is in extent tensor form,
+ out-of-bounds indices result in undefined behavior.
Examples:
- split_at([4,5,6], index=0) -> [], [4,5,6]
@@ -623,7 +624,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
Shape_SizeOrIndexType:$index);
- let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+ let results = (outs Shape_ShapeOrExtentTensorType:$head,
+ Shape_ShapeOrExtentTensorType:$tail);
let hasFolder = 1;
}
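A hedged sketch of the two result forms the relaxed signature now permits
(generic op syntax assumed; SSA names are illustrative). In shape form an
out-of-bounds index produces an error value; in extent tensor form it is
undefined behavior:

  // Error-propagating form: results stay in the shape dialect's type.
  %h0, %t0 = "shape.split_at"(%s, %i)
      : (!shape.shape, index) -> (!shape.shape, !shape.shape)
  // Extent tensor form: only valid when %i is known to be in bounds.
  %h1, %t1 = "shape.split_at"(%e, %i)
      : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)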
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 2b5d619bf58e..49c44ad78e8f 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -590,6 +590,47 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
return success();
}
+namespace {
+class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+public:
+ using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+ SplitAtOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // Error conditions are not implemented; only lower if all operands and
+ // results are extent tensors.
+ if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+ [](Value v) { return v.getType().isa<ShapeType>(); }))
+ return failure();
+
+ SplitAtOp::Adaptor transformed(op);
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ Value zero = b.create<ConstantIndexOp>(0);
+ Value rank = b.create<DimOp>(transformed.operand(), zero);
+
+ // index < 0 ? index + rank : index
+ Value originalIndex = transformed.index();
+ Value add = b.create<AddIOp>(originalIndex, rank);
+ Value indexIsNegative =
+ b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+ Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+ Value one = b.create<ConstantIndexOp>(1);
+ Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
+ Value tailSize = b.create<SubIOp>(rank, index);
+ Value tail =
+ b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
+ rewriter.replaceOp(op, {head, tail});
+ return success();
+}
+
namespace {
class ToExtentTensorOpConversion
: public OpConversionPattern<ToExtentTensorOp> {
@@ -660,6 +701,7 @@ void mlir::populateShapeToStandardConversionPatterns(
ReduceOpConverter,
ShapeEqOpConverter,
ShapeOfOpConversion,
+ SplitAtOpConversion,
ToExtentTensorOpConversion>(ctx);
// clang-format on
}
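To make the emitted arithmetic concrete, here is a rough sketch of the IR this
pattern produces (SSA names chosen for readability; the CHECK lines in the test
below are the authoritative reference), traced for an operand holding the
extents [4, 5, 6] and a split index of -1:

  %c0 = constant 0 : index
  %rank = dim %operand, %c0 : tensor<?xindex>            // rank = 3
  %add = addi %index, %rank : index                      // -1 + 3 = 2
  %isneg = cmpi slt, %index, %c0 : index                 // true
  %idx = select %isneg, %add, %index : index             // 2
  %c1 = constant 1 : index
  %head = subtensor %operand[%c0] [%idx] [%c1]
      : tensor<?xindex> to tensor<?xindex>               // extents [4, 5]
  %tailsize = subi %rank, %idx : index                   // 3 - 2 = 1
  %tail = subtensor %operand[%idx] [%tailsize] [%c1]
      : tensor<?xindex> to tensor<?xindex>               // extents [6]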
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d8aec027a11e..a4a0f7ece9b4 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -592,3 +592,23 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
: tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
return
}
+
+// -----
+
+// Lower `split_at`
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+ // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+ // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+ // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+ // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+ // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+ // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+ // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+ // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+ // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+ // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+ %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+ return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}