[Mlir-commits] [mlir] 6673c6c - [MLIR][Shape] Limit shape to standard lowerings to their supported types
Frederik Gossen
llvmlistbot at llvm.org
Wed Jul 29 06:57:16 PDT 2020
Author: Frederik Gossen
Date: 2020-07-29T13:56:52Z
New Revision: 6673c6cd82f79b76c1676ab1ab30a288286acb71
URL: https://github.com/llvm/llvm-project/commit/6673c6cd82f79b76c1676ab1ab30a288286acb71
DIFF: https://github.com/llvm/llvm-project/commit/6673c6cd82f79b76c1676ab1ab30a288286acb71.diff
LOG: [MLIR][Shape] Limit shape to standard lowerings to their supported types
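The lowering does not support all types of its source operations. This change
makes the patterns fail in a well-defined manner on unsupported types (e.g.
`!shape.size` and `!shape.shape`), and the pass now applies a partial rather
than a full conversion so that such operations are left in place.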
Differential Revision: https://reviews.llvm.org/D84443
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 4deaa8cd2df3..41d4d90b33d3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -49,8 +49,14 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
LogicalResult
matchAndRewrite(SrcOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- typename SrcOpTy::Adaptor adaptor(operands);
- rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.lhs(), adaptor.rhs());
+ typename SrcOpTy::Adaptor transformed(operands);
+
+ // For now, only error-free types are supported by this lowering.
+ if (op.getType().template isa<SizeType>())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<DstOpTy>(op, transformed.lhs(),
+ transformed.rhs());
return success();
}
};
@@ -85,27 +91,31 @@ class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
LogicalResult ShapeOfOpConversion::matchAndRewrite(
ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- ShapeOfOp::Adaptor transformed(operands);
- auto loc = op.getLoc();
- auto tensorVal = transformed.arg();
- auto tensorTy = tensorVal.getType();
+
+ // For now, only error-free types are supported by this lowering.
+ if (op.getType().isa<ShapeType>())
+ return failure();
// For unranked tensors `shape_of` lowers to `scf` and the pattern can be
// found in the corresponding pass.
+ ShapeOfOp::Adaptor transformed(operands);
+ Value tensorVal = transformed.arg();
+ Type tensorTy = tensorVal.getType();
if (tensorTy.isa<UnrankedTensorType>())
return failure();
// Build values for individual dimensions.
SmallVector<Value, 8> dimValues;
- auto rankedTensorTy = tensorTy.cast<RankedTensorType>();
+ RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
int64_t rank = rankedTensorTy.getRank();
+ auto loc = op.getLoc();
for (int64_t i = 0; i < rank; i++) {
if (rankedTensorTy.isDynamicDim(i)) {
- auto dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
+ Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
dimValues.push_back(dimVal);
} else {
int64_t dim = rankedTensorTy.getDimSize(i);
- auto dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
+ Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
dimValues.push_back(dimVal);
}
}
@@ -187,11 +197,18 @@ LogicalResult GetExtentOpConverter::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
GetExtentOp::Adaptor transformed(operands);
- // Derive shape extent directly from shape origin if possible.
- // This circumvents the necessity to materialize the shape in memory.
+ // For now, only error-free types are supported by this lowering.
+ if (op.getType().isa<SizeType>())
+ return failure();
+
+ // Derive shape extent directly from shape origin if possible. This
+ // circumvents the necessity to materialize the shape in memory.
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
- rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(), transformed.dim());
- return success();
+ if (shapeOfOp.arg().getType().isa<ShapedType>()) {
+ rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
+ transformed.dim());
+ return success();
+ }
}
rewriter.replaceOpWithNewOp<ExtractElementOp>(op, rewriter.getIndexType(),
@@ -241,7 +258,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
// Apply conversion.
auto module = getOperation();
- if (failed(applyFullConversion(module, target, patterns)))
+ if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 0e30cc2bdf56..3f19de9c52f0 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -13,6 +13,30 @@ func @binary_ops(%lhs : index, %rhs : index) {
// -----
+// Don't lower binary ops when they operate on `shape.size`.
+// CHECK-LABEL: @binary_ops_on_size
+// CHECK-SAME: (%[[LHS:.*]]: !shape.size, %[[RHS:.*]]: !shape.size)
+func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
+ // CHECK: shape.add %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+ // CHECK: shape.mul %[[LHS]], %[[RHS]] : !shape.size, !shape.size -> !shape.size
+ %sum = shape.add %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+ %prod = shape.mul %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+ return
+}
+
+// -----
+
+// Don't lower `shape_of` with `shape.shape` type.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+ // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+ %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+ return
+}
+
+// -----
+
// Lower `shape_of` for statically shaped tensor.
// CHECK-LABEL: @shape_of_stat
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
@@ -55,6 +79,17 @@ func @rank(%shape : tensor<?xindex>) -> index {
// -----
+// Don't lower `get_extent` if it is of type `shape.size`.
+// CHECK-LABEL: @get_extent
+func @get_extent(%shape : tensor<?xindex>, %idx : !shape.size) -> !shape.size {
+ // CHECK: shape.get_extent
+ %result = shape.get_extent %shape, %idx
+ : tensor<?xindex>, !shape.size -> !shape.size
+ return %result : !shape.size
+}
+
+// -----
+
// Express `get_extent` as `std.dim` when it relies directly on the outcome of a
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
More information about the Mlir-commits mailing list