[Mlir-commits] [mlir] 5fc34fa - [MLIR][Shape] Limit shape-to-SCF lowering patterns to their supported types

Frederik Gossen llvmlistbot at llvm.org
Wed Jul 29 07:54:30 PDT 2020


Author: Frederik Gossen
Date: 2020-07-29T14:54:09Z
New Revision: 5fc34fafa72e311763b43f3041bb7113c8ae5b6f

URL: https://github.com/llvm/llvm-project/commit/5fc34fafa72e311763b43f3041bb7113c8ae5b6f
DIFF: https://github.com/llvm/llvm-project/commit/5fc34fafa72e311763b43f3041bb7113c8ae5b6f.diff

LOG: [MLIR][Shape] Limit shape-to-SCF lowering patterns to their supported types

Differential Revision: https://reviews.llvm.org/D84444
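
The guarded patterns now bail out on the error-carrying `!shape.shape` type and
only fire on the error-free forms such as `tensor<?xindex>`. A minimal sketch of
the distinction, mirroring the tests below (the function name is illustrative):

    func @shape_of_example(%arg : tensor<*xf32>) {
      // Not converted by this pass: the result is the error-carrying `!shape.shape`.
      %s0 = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
      // Still converted: lowered to an `scf.for` loop over the rank of %arg.
      %s1 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
      return
    }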

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index 824db2685a77..0101c9e7fdc0 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -121,7 +121,7 @@ LogicalResult
 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (!op.shape().getType().isa<RankedTensorType>())
+  if (op.shape().getType().isa<ShapeType>())
     return failure();
 
   auto loc = op.getLoc();
@@ -171,12 +171,15 @@ class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
 LogicalResult
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
+  // For now, this lowering supports only error-free arguments.
+  if (op.getType().isa<ShapeType>())
+    return failure();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
   if (argTy.isa<RankedTensorType>())
     return failure();
 

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 441b2e92cc3d..97d2bce5a094 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
 
 // -----
 
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+  // CHECK: shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+  return
+}
+
+// -----
+
 // Lower `shape_of` for unranked tensors.
 // CHECK-LABEL: @shape_of_unranked
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK-DAG:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK-DAG:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK-DAG:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
-  // CHECK:     }
-  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK:   %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
