[Mlir-commits] [mlir] dfcc098 - [MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`
Frederik Gossen
llvmlistbot at llvm.org
Tue Jul 28 08:41:24 PDT 2020
Author: Frederik Gossen
Date: 2020-07-28T15:40:55Z
New Revision: dfcc09890a91b1085139fee175936b0e67824e47
URL: https://github.com/llvm/llvm-project/commit/dfcc09890a91b1085139fee175936b0e67824e47
DIFF: https://github.com/llvm/llvm-project/commit/dfcc09890a91b1085139fee175936b0e67824e47.diff
LOG: [MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`
Differential Revision: https://reviews.llvm.org/D82848
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f239d1cfb4f0..b84b6ba3b5d6 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -103,6 +103,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
return success();
}
+namespace {
+/// Lowers `shape.const_shape` on extent tensors to `tensor_from_elements`.
+class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
+public:
+  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ConstShapeOpConverter::matchAndRewrite(
+    ConstShapeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Only extent tensors are lowered; `shape.shape` results are rejected.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  // Materialize each constant extent, in order, as an `index` constant.
+  Location loc = op.getLoc();
+  SmallVector<Value, 4> elements;
+  for (auto dim : op.shape())
+    elements.push_back(
+        rewriter.create<ConstantIndexOp>(loc, dim.getLimitedValue()));
+  // Pack into a statically sized tensor, then cast to `tensor<?xindex>`.
+  Value packed = rewriter.create<TensorFromElementsOp>(loc, elements);
+  Type dynExtentTensorTy = RankedTensorType::get(
+      {ShapedType::kDynamicSize}, rewriter.getIndexType());
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, packed, dynExtentTensorTy);
+  return success();
+}
+
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
+ ConstShapeOpConverter,
BinaryOpConversion<MulOp, MulIOp>,
GetExtentOpConverter,
RankOpConverter,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 9336402d86da..7f875f3bb19f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
// -----
+// Lower `const_shape` to `index` constants packed by `tensor_from_elements`.
+// CHECK-LABEL: @const_shape
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape() -> tensor<?xindex> {
+  // CHECK: %[[EXT1:.*]] = constant 1 : index
+  // CHECK: %[[EXT2:.*]] = constant 2 : index
+  // CHECK: %[[EXT3:.*]] = constant 3 : index
+  // CHECK: %[[STATIC:.*]] = tensor_from_elements(%[[EXT1]], %[[EXT2]], %[[EXT3]])
+  // CHECK: %[[DYNAMIC:.*]] = tensor_cast %[[STATIC]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[DYNAMIC]] : tensor<?xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  return %shape : tensor<?xindex>
+}
+
+// -----
+
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
More information about the Mlir-commits
mailing list