[Mlir-commits] [mlir] dfcc098 - [MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`

Frederik Gossen llvmlistbot at llvm.org
Tue Jul 28 08:41:24 PDT 2020


Author: Frederik Gossen
Date: 2020-07-28T15:40:55Z
New Revision: dfcc09890a91b1085139fee175936b0e67824e47

URL: https://github.com/llvm/llvm-project/commit/dfcc09890a91b1085139fee175936b0e67824e47
DIFF: https://github.com/llvm/llvm-project/commit/dfcc09890a91b1085139fee175936b0e67824e47.diff

LOG: [MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`

Differential Revision: https://reviews.llvm.org/D82848

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f239d1cfb4f0..b84b6ba3b5d6 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -103,6 +103,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
   return success();
 }
 
+namespace {
+/// Lowers an extent-tensor `shape.const_shape` to `tensor_from_elements`.
+class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
+public:
+  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ConstShapeOpConverter::matchAndRewrite(
+    ConstShapeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering supports only extent tensors, not `shape.shape`.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  SmallVector<Value, 4> extentOperands;
+  for (auto extent : op.shape())
+    extentOperands.push_back(
+        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
+  // Explicit element type: an empty shape has no extent to infer it from.
+  Type indexTy = rewriter.getIndexType();
+  Value tensor =
+      rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
+  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
+      ConstShapeOpConverter,
       BinaryOpConversion<MulOp, MulIOp>,
       GetExtentOpConverter,
       RankOpConverter,

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 9336402d86da..7f875f3bb19f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
 
 // -----
 
+// Lower `const_shape` to index constants packed by `tensor_from_elements`.
+// CHECK-LABEL: @const_shape
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape() -> tensor<?xindex> {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
+  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  return %shape : tensor<?xindex>
+}
+
+// -----
+
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME:  (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>


        


More information about the Mlir-commits mailing list