[Mlir-commits] [mlir] 5fc34fa - [MLIR][Shape] Limit shape to SCF lowering patterns to their supported types
Frederik Gossen
llvmlistbot at llvm.org
Wed Jul 29 07:54:30 PDT 2020
Author: Frederik Gossen
Date: 2020-07-29T14:54:09Z
New Revision: 5fc34fafa72e311763b43f3041bb7113c8ae5b6f
URL: https://github.com/llvm/llvm-project/commit/5fc34fafa72e311763b43f3041bb7113c8ae5b6f
DIFF: https://github.com/llvm/llvm-project/commit/5fc34fafa72e311763b43f3041bb7113c8ae5b6f.diff
LOG: [MLIR][Shape] Limit shape to SCF lowering patterns to their supported types
Differential Revision: https://reviews.llvm.org/D84444
Added:
Modified:
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 824db2685a77..0101c9e7fdc0 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -121,7 +121,7 @@ LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// For now, this lowering is only defined on `tensor<?xindex>` operands.
- if (!op.shape().getType().isa<RankedTensorType>())
+ if (op.shape().getType().isa<ShapeType>())
return failure();
auto loc = op.getLoc();
@@ -171,12 +171,15 @@ class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
- ShapeOfOp::Adaptor transformed(operands);
- Value arg = transformed.arg();
- Type argTy = arg.getType();
+ // For now, this lowering supports only error-free arguments.
+ if (op.getType().isa<ShapeType>())
+ return failure();
// For ranked tensors `shape_of` lowers to `std` and the pattern can be
// found in the corresponding pass.
+ ShapeOfOp::Adaptor transformed(operands);
+ Value arg = transformed.arg();
+ Type argTy = arg.getType();
if (argTy.isa<RankedTensorType>())
return failure();
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 441b2e92cc3d..97d2bce5a094 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
// -----
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+ // CHECK: shape.shape
+ %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+ return
+}
+
+// -----
+
// Lower `shape_of` for unranked tensors.
// CHECK-LABEL: @shape_of_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @shape_of_unranked(%arg : tensor<*xf32>) {
- // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
- // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
- // CHECK-DAG: %[[C0:.*]] = constant 0 : index
- // CHECK-DAG: %[[C1:.*]] = constant 1 : index
- // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
- // CHECK-DAG: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
- // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
- // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
- // CHECK: }
- // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
- // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+ // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+ // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+ // CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+ // CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+ // CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+ // CHECK: }
+ // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+ // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
return
}
More information about the Mlir-commits mailing list