[Mlir-commits] [mlir] 11492be - [MLIR][Shape] Lower `shape.broadcast` to `scf`

Frederik Gossen llvmlistbot at llvm.org
Mon Aug 3 01:20:36 PDT 2020


Author: Frederik Gossen
Date: 2020-08-03T08:20:14Z
New Revision: 11492be9d72d4215ac2f61626264da05fee35e78

URL: https://github.com/llvm/llvm-project/commit/11492be9d72d4215ac2f61626264da05fee35e78
DIFF: https://github.com/llvm/llvm-project/commit/11492be9d72d4215ac2f61626264da05fee35e78.diff

LOG: [MLIR][Shape] Lower `shape.broadcast` to `scf`

Differential Revision: https://reviews.llvm.org/D85027

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
    mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index a6c667f5641c..ae326c5c513e 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -19,6 +19,98 @@ using namespace mlir;
 using namespace mlir::shape;
 using namespace mlir::scf;
 
+namespace {
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+  using OpConversionPattern<BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+    BroadcastOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  assert(!op.lhs().getType().isa<ShapeType>() &&
+         !op.rhs().getType().isa<ShapeType>());
+  auto loc = op.getLoc();
+  BroadcastOp::Adaptor transformed(operands);
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsSmaller =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Type extentTensorTy = op.getType();
+  auto ifOp = rewriter.create<IfOp>(
+      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+      lhsSmaller,
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+                                               rhsRank, transformed.rhs()});
+      },
+      [&](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+                                               lhsRank, transformed.lhs()});
+      });
+  Value smallerRank = ifOp.getResult(0);
+  Value smallerOperand = ifOp.getResult(1);
+  Value greaterRank = ifOp.getResult(2);
+  Value greaterOperand = ifOp.getResult(3);
+
+  // Allocate stack memory for the broadcasted extent tensor.
+  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+  // Copy extents from greater operand that are not challenged.
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+  rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
+                         [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+                           Value extent = b.create<ExtractElementOp>(
+                               loc, greaterOperand, ValueRange{iv});
+                           b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+                           b.create<scf::YieldOp>(loc);
+                         });
+
+  // Determine remaining broadcasted extents.
+  rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, llvm::None,
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+        Value greaterOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+        Value greaterOperandExtentIsOne =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+        auto ifOp = b.create<IfOp>(
+            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            [&](OpBuilder &b, Location loc) {
+              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+              Value smallerOperandExtent = b.create<ExtractElementOp>(
+                  loc, smallerOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+            },
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+            });
+        Value extent = ifOp.getResult(0);
+        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+        b.create<scf::YieldOp>(loc);
+      });
+
+  // Load broadcasted shape as an extent tensor.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
+  return success();
+}
+
 namespace {
 /// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
 /// only defined on `tensor<?xindex>` operands. The test for equality first
@@ -223,7 +315,6 @@ void ConvertShapeToSCFPass::runOnFunction() {
   // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<SCFDialect, StandardOpsDialect>();
-  target.addLegalOp<ModuleOp, FuncOp>();
 
   // Apply conversion.
   if (failed(applyPartialConversion(getFunction(), target, patterns)))
@@ -234,6 +325,7 @@ void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // clang-format off
   patterns.insert<
+      BroadcastOpConverter,
       ShapeEqOpConverter,
       ReduceOpConverter,
       ShapeOfOpConverter>(ctx);

diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 768a627208b8..cc384496dff0 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -80,3 +80,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
   %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
   return %result : i1
 }
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+  // CHECK: shape.broadcast
+  %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+  return %c : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast
+// CHECK-SAME:  (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: } else {
+  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+  // CHECK: }
+  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+  // CHECK: }
+  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  %0 = shape.broadcast %a, %b
+      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return
+}
+


        


More information about the Mlir-commits mailing list