[Mlir-commits] [mlir] 11492be - [MLIR][Shape] Lower `shape.broadcast` to `scf`
Frederik Gossen
llvmlistbot at llvm.org
Mon Aug 3 01:20:36 PDT 2020
Author: Frederik Gossen
Date: 2020-08-03T08:20:14Z
New Revision: 11492be9d72d4215ac2f61626264da05fee35e78
URL: https://github.com/llvm/llvm-project/commit/11492be9d72d4215ac2f61626264da05fee35e78
DIFF: https://github.com/llvm/llvm-project/commit/11492be9d72d4215ac2f61626264da05fee35e78.diff
LOG: [MLIR][Shape] Lower `shape.broadcast` to `scf`
Differential Revision: https://reviews.llvm.org/D85027
Added:
Modified:
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
index a6c667f5641c..ae326c5c513e 100644
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -19,6 +19,98 @@ using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;
+namespace {
+/// Converts `shape.broadcast` to `scf` ops. For now, the lowering is only
+/// defined on `tensor<?xindex>` operands, not on shapes.
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+ // Reuse the constructors of the base conversion pattern.
+ using OpConversionPattern<BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Lowers `shape.broadcast` to `scf.if`/`scf.for` ops. The broadcasted
+/// extents are stored into an `alloca`'d `memref<?xindex>` whose size is the
+/// greater of the two operand ranks, and the op is replaced by a
+/// `tensor_load` of that buffer. Returns failure, before creating any op,
+/// when the result type is a `ShapeType`.
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+ BroadcastOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+ // on shapes.
+ if (op.getType().isa<ShapeType>())
+ return failure();
+
+ // The result is not a shape at this point; the operands are asserted not to
+ // be shapes either.
+ assert(!op.lhs().getType().isa<ShapeType>() &&
+ !op.rhs().getType().isa<ShapeType>());
+ auto loc = op.getLoc();
+ // Named access (`lhs()`, `rhs()`) to the `operands` array.
+ BroadcastOp::Adaptor transformed(operands);
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+ // Find smaller and greater rank and extent tensor.
+ // The size of dimension 0 of each operand is taken as its rank.
+ Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+ Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+ // `ule`: on equal ranks the lhs is treated as the smaller operand.
+ Value lhsSmaller =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+ Type indexTy = rewriter.getIndexType();
+ Type extentTensorTy = op.getType();
+ // The `scf.if` yields (smaller rank, smaller operand, greater rank, greater
+ // operand).
+ // NOTE(review): both operands are yielded with the result's type; this
+ // presumably requires lhs and rhs to have exactly that type -- confirm.
+ auto ifOp = rewriter.create<IfOp>(
+ loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+ lhsSmaller,
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+ rhsRank, transformed.rhs()});
+ },
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+ lhsRank, transformed.lhs()});
+ });
+ Value smallerRank = ifOp.getResult(0);
+ Value smallerOperand = ifOp.getResult(1);
+ Value greaterRank = ifOp.getResult(2);
+ Value greaterOperand = ifOp.getResult(3);
+
+ // Allocate stack memory for the broadcasted extent tensor.
+ // The buffer has `greaterRank` elements of index type.
+ Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+ Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+ // Copy extents from greater operand that are not challenged.
+ // These are the leading `rankDiff` extents, for which the smaller operand
+ // has no counterpart.
+ Value rankDiff =
+ rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+ rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+ Value extent = b.create<ExtractElementOp>(
+ loc, greaterOperand, ValueRange{iv});
+ b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+ b.create<scf::YieldOp>(loc);
+ });
+
+ // Determine remaining broadcasted extents.
+ // For each index in [rankDiff, greaterRank): if the greater operand's extent
+ // equals 1, take the smaller operand's extent at `iv - rankDiff`; otherwise
+ // keep the greater operand's extent.
+ // NOTE(review): no other comparison of the two extents is emitted, so
+ // mismatching extents are presumably not diagnosed here -- confirm intended.
+ rewriter.create<ForOp>(
+ loc, rankDiff, greaterRank, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+ Value greaterOperandExtent =
+ b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+ Value greaterOperandExtentIsOne =
+ b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+ // This `ifOp` shadows the outer one that selected the operands.
+ auto ifOp = b.create<IfOp>(
+ loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+ [&](OpBuilder &b, Location loc) {
+ // Shift the index so that it addresses the smaller operand.
+ Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+ Value smallerOperandExtent = b.create<ExtractElementOp>(
+ loc, smallerOperand, ValueRange{ivShifted});
+ b.create<scf::YieldOp>(loc, smallerOperandExtent);
+ },
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, greaterOperandExtent);
+ });
+ Value extent = ifOp.getResult(0);
+ b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+ b.create<scf::YieldOp>(loc);
+ });
+
+ // Load broadcasted shape as an extent tensor.
+ rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
+ return success();
+}
+
namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
@@ -223,7 +315,6 @@ void ConvertShapeToSCFPass::runOnFunction() {
// Setup target legality.
ConversionTarget target(getContext());
target.addLegalDialect<SCFDialect, StandardOpsDialect>();
- target.addLegalOp<ModuleOp, FuncOp>();
// Apply conversion.
if (failed(applyPartialConversion(getFunction(), target, patterns)))
@@ -234,6 +325,7 @@ void mlir::populateShapeToSCFConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// clang-format off
patterns.insert<
+ BroadcastOpConverter,
ShapeEqOpConverter,
ReduceOpConverter,
ShapeOfOpConverter>(ctx);
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
index 768a627208b8..cc384496dff0 100644
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -80,3 +80,53 @@ func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
%result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
return %result : i1
}
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved. The op
+// is expected to still be present in the output.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+ // CHECK: shape.broadcast
+ %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+ return %c : !shape.shape
+}
+
+// -----
+
+// Lower `shape.broadcast` on extent tensors. The expected output selects the
+// smaller and greater operand with `scf.if`, copies the leading extents of
+// the greater operand into an `alloca`'d buffer, picks each remaining extent
+// with an `scf.if` nested in an `scf.for`, and loads the buffer as a tensor.
+// CHECK-LABEL: @broadcast
+// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+ // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+ // CHECK: scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+ // CHECK: } else {
+ // CHECK: scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+ // CHECK: }
+ // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
+ // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+ // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+ // CHECK: %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+ // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+ // CHECK: }
+ // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+ // CHECK: %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+ // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+ // CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+ // CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+ // CHECK: %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+ // CHECK: scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+ // CHECK: } else {
+ // CHECK: scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+ // CHECK: }
+ // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+ // CHECK: }
+ // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+ %0 = shape.broadcast %a, %b
+ : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ return
+}
+
More information about the Mlir-commits
mailing list