[Mlir-commits] [mlir] 8178e41 - [mlir] Type erase inputs to select statements in shape.broadcast lowering.

Tres Popp llvmlistbot at llvm.org
Sun Oct 11 12:58:17 PDT 2020


Author: Tres Popp
Date: 2020-10-11T21:58:06+02:00
New Revision: 8178e41dc1a395fc0a99f945bf27fbfca871d3e5

URL: https://github.com/llvm/llvm-project/commit/8178e41dc1a395fc0a99f945bf27fbfca871d3e5
DIFF: https://github.com/llvm/llvm-project/commit/8178e41dc1a395fc0a99f945bf27fbfca871d3e5.diff

LOG: [mlir] Type erase inputs to select statements in shape.broadcast lowering.

This is required; otherwise, broadcasting with operands of different ranks
leads to failures, as the select op requires its two candidate values and its
result to all have the same type.
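
To see why, here is a minimal sketch (not part of the commit; the value names
are hypothetical) using the same select and tensor_cast syntax as the tests
below. select verifies that both candidate values and the result share one
type, so operands with different static extents are rejected:

  // Rejected by the verifier: tensor<2xindex> vs. tensor<3xindex>, and no
  // single type fits the result.
  %bad = select %pred, %lhs, %rhs : tensor<2xindex>

  // Accepted once the static extents are erased to a common dynamic type:
  %lhs_erased = tensor_cast %lhs : tensor<2xindex> to tensor<?xindex>
  %rhs_erased = tensor_cast %rhs : tensor<3xindex> to tensor<?xindex>
  %ok = select %pred, %lhs_erased, %rhs_erased : tensor<?xindex>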

Differential Revision: https://reviews.llvm.org/D89134

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index b1319a8cd386..2840761b66bd 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -99,10 +99,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
       rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
   Value greaterRank =
       rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Value rankErasedLhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+  Value rankErasedRhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
   Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
   Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
 
   // Allocate stack memory for the broadcasted extent tensor.
   Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 6207486db821..4dc4868f053a 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,9 +305,9 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL: @broadcast
+// CHECK-LABEL: @broadcast_unknown_extents
 // CHECK-SAME:  (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
@@ -315,8 +315,10 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
   // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
   // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
   // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
   // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
@@ -340,3 +342,43 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @broadcast_known_different_extents
+// CHECK-SAME:  (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
+func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  %0 = shape.broadcast %a, %b
+      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+  return
+}
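
For reference, the new case runs the same way as the rest of this file,
presumably through its existing RUN line, along the lines of:

  mlir-opt -split-input-file -convert-shape-to-std shape-to-standard.mlir \
    | FileCheck shape-to-standard.mlir

(the exact flags are whatever the file's RUN line specifies; -split-input-file
and -convert-shape-to-std are assumptions here).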

More information about the Mlir-commits mailing list