[Mlir-commits] [mlir] 511484f - [mlir] Add lowering for IsBroadcastable to Std dialect.
Tres Popp
llvmlistbot at llvm.org
Fri Oct 30 02:44:36 PDT 2020
Author: Tres Popp
Date: 2020-10-30T10:44:27+01:00
New Revision: 511484f27d923e16cbb88d3b00f800386a263cd5
URL: https://github.com/llvm/llvm-project/commit/511484f27d923e16cbb88d3b00f800386a263cd5
DIFF: https://github.com/llvm/llvm-project/commit/511484f27d923e16cbb88d3b00f800386a263cd5.diff
LOG: [mlir] Add lowering for IsBroadcastable to Std dialect.
Differential Revision: https://reviews.llvm.org/D90407
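
The new pattern lowers `shape.is_broadcastable` on `tensor<?xindex>` operands
to standard and SCF ops: it compares the two ranks, selects the lesser- and
greater-ranked extent tensor, and folds a conjunction over the overlapping
extents, where a pair of extents is compatible if either is 1 or both are
equal. As a scalar model of that logic (the function name and std::vector
interface below are illustrative only, not part of this commit), the generated
IR computes the equivalent of:

  #include <cstdint>
  #include <vector>

  // Scalar sketch of the check the lowered IR performs; names are
  // hypothetical.
  bool isBroadcastable(const std::vector<int64_t> &lhs,
                       const std::vector<int64_t> &rhs) {
    // Mirror the cmpi "ule" + select sequence: pick the lesser- and
    // greater-ranked shape.
    const std::vector<int64_t> &lesser = lhs.size() <= rhs.size() ? lhs : rhs;
    const std::vector<int64_t> &greater = lhs.size() <= rhs.size() ? rhs : lhs;
    size_t rankDiff = greater.size() - lesser.size();

    // Mirror the scf.for over the overlapping extents, carrying the running
    // conjunction the way the loop carries its iter_arg.
    bool allSoFar = true;
    for (size_t i = rankDiff; i < greater.size(); ++i) {
      int64_t g = greater[i];
      int64_t l = lesser[i - rankDiff];
      allSoFar = allSoFar && (g == 1 || l == 1 || g == l);
    }
    return allSoFar;
  }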
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0c300c4f9cc5..704b0cdb0324 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -207,6 +207,86 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite(
return success();
}
+namespace {
+struct IsBroadcastableOpConverter
+ : public OpConversionPattern<IsBroadcastableOp> {
+ using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
+ IsBroadcastableOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+ // on shapes.
+ IsBroadcastableOp::Adaptor transformed(operands);
+ if (transformed.lhs().getType().isa<ShapeType>() ||
+ transformed.rhs().getType().isa<ShapeType>())
+ return failure();
+
+ auto loc = op.getLoc();
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+ // Find smaller and greater rank and extent tensor.
+ Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+ Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+ Value lhsRankULE =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+ Type indexTy = rewriter.getIndexType();
+ Value lesserRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+ Value greaterRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+ auto erasedRankType =
+ RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+ Value rankErasedLhs =
+ rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+ Value rankErasedRhs =
+ rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+ Value lesserRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
+ Value greaterRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+ Value rankDiff =
+ rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+ Type i1Ty = rewriter.getI1Type();
+ Value init =
+ rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+
+ // Determine if all overlapping extents are broadcastable.
+ auto reduceResult = rewriter.create<ForOp>(
+ loc, rankDiff, greaterRank, one, ValueRange{init},
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+ Value greaterRankOperandExtent =
+ b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
+ Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+ loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+ Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+ Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+ loc, lesserRankOperand, ValueRange{ivShifted});
+ Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+ loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+ Value extentsAreEqual =
+ b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+ lesserRankOperandExtent);
+ Value broadcastableExtents = b.create<AndOp>(
+ loc, iterArgs[0],
+ b.create<OrOp>(loc,
+ b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+ lesserRankOperandExtentIsOne),
+ extentsAreEqual));
+ b.create<scf::YieldOp>(loc, broadcastableExtents);
+ });
+
+ rewriter.replaceOp(op, reduceResult.results().front());
+ return success();
+}
+
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -522,6 +602,7 @@ void mlir::populateShapeToStandardConversionPatterns(
BroadcastOpConverter,
ConstShapeOpConverter,
ConstSizeOpConversion,
+ IsBroadcastableOpConverter,
GetExtentOpConverter,
RankOpConverter,
ReduceOpConverter,
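
For instance (illustrative shapes, not taken from the test below), [4, 1, 3]
and [5, 3] are broadcastable: right-aligned, 1 vs 5 passes because one extent
is 1, and 3 vs 3 passes by equality; the leading 4 falls below the rank
difference and is never compared. Against the scalar sketch above:

  // Uses the isBroadcastable sketch above; requires <cassert>.
  assert(isBroadcastable({4, 1, 3}, {5, 3}));   // 1 broadcasts against 5.
  assert(!isBroadcastable({4, 2, 3}, {5, 3}));  // 2 vs 5: unequal, neither is 1.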
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d827b6f4cc2e..56594d529e4d 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -382,3 +382,41 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
: tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return
}
+
+// -----
+
+func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
+ %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
+ return %0 : i1
+}
+
+// CHECK-LABEL: func @try_is_broadcastable(
+// CHECK-SAME: %[[LHS:.*]]: tensor<3xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
+// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
+// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK: }
+// CHECK: return %[[ALL_RESULT]] : i1
+// CHECK: }
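
Note that the loop carries the running conjunction as an iter_arg
(%ALL_SO_FAR), seeded with true and updated by scf.yield, so the value
produced by scf.for at loop exit is the overall broadcastability predicate;
when the ranks are equal, %RANK_DIFF is 0 and every extent pair is checked.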