[Mlir-commits] [mlir] ac3e5c4 - [MLIR][Shape] Lower `shape.shape_of` to standard dialect
Frederik Gossen
llvmlistbot at llvm.org
Fri Jun 19 08:22:09 PDT 2020
Author: Frederik Gossen
Date: 2020-06-19T15:21:13Z
New Revision: ac3e5c4d93fbe7fb2db3c745c721aff41cc1b851
URL: https://github.com/llvm/llvm-project/commit/ac3e5c4d93fbe7fb2db3c745c721aff41cc1b851
DIFF: https://github.com/llvm/llvm-project/commit/ac3e5c4d93fbe7fb2db3c745c721aff41cc1b851.diff
LOG: [MLIR][Shape] Lower `shape.shape_of` to standard dialect
Lower `shape.shape_of` to the standard dialect.
This lowering supports statically and dynamically shaped tensors.
Support for unranked tensors will be added as part of the lowering to `scf`.
Differential Revision: https://reviews.llvm.org/D82098
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index d02f5e3de116..6a02bdc2c286 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -38,6 +38,45 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
}
};
+/// Converts `shape.shape_of` applied to a ranked tensor into a tensor of
+/// extents built with `TensorFromElementsOp`.
+///
+/// Statically known extents are materialized with `ConstantIndexOp`; dynamic
+/// extents are queried from the operand with `DimOp`. For any operand that is
+/// not a ranked tensor the pattern reports failure and leaves the op alone.
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ShapeOfOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto tensorVal = transformed.arg();
+
+    // Only ranked tensors are handled here. For unranked tensors `shape_of`
+    // lowers to `scf` and the pattern can be found in the corresponding pass.
+    // `dyn_cast` (instead of an `isa<UnrankedTensorType>` check followed by an
+    // unconditional `cast`) also rejects every other operand type with a
+    // regular match failure rather than an assertion.
+    auto rankedTensorTy = tensorVal.getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorTy)
+      return failure();
+
+    // Build values for individual dimensions.
+    int64_t rank = rankedTensorTy.getRank();
+    SmallVector<Value, 8> dimValues;
+    dimValues.reserve(rank);
+    for (int64_t i = 0; i < rank; i++) {
+      if (rankedTensorTy.isDynamicDim(i)) {
+        // Extent is only known at runtime: query it from the tensor value.
+        auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+        dimValues.push_back(dimVal);
+      } else {
+        // Extent is part of the type: emit it as a constant.
+        int64_t dim = rankedTensorTy.getDimSize(i);
+        auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+        dimValues.push_back(dimVal);
+      }
+    }
+
+    // Materialize shape as ranked tensor.
+    rewriter.replaceOpWithNewOp<TensorFromElementsOp>(op.getOperation(),
+                                                      dimValues);
+    return success();
+  }
+};
+
class ConstSizeOpConverter : public OpConversionPattern<ConstSizeOp> {
public:
using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
@@ -107,7 +146,8 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
BinaryOpConversion<AddOp, AddIOp>,
BinaryOpConversion<MulOp, MulIOp>,
- ConstSizeOpConverter>(ctx);
+ ConstSizeOpConverter,
+ ShapeOfOpConversion>(ctx);
// clang-format on
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 1caf0051f37b..bfe3c2b599c5 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -86,3 +86,32 @@ func @size_const() -> !shape.size {
}
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: return %[[C1]] : index
+
+// -----
+
+// Lower `shape_of` for a statically shaped tensor: each extent becomes a constant index fed to `tensor_from_elements`.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+ // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+ // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+ %shape = shape.shape_of %arg : tensor<1x2x3xf32>
+ return
+}
+
+// -----
+
+// Lower `shape_of` for a dynamically shaped tensor: static extents become constants, the dynamic extent is read with `dim`.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+ // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+ // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+ // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+ %shape = shape.shape_of %arg : tensor<1x5x?xf32>
+ return
+}
More information about the Mlir-commits mailing list