[Mlir-commits] [mlir] 3842d4b - Make shape.is_broadcastable/shape.cstr_broadcastable nary
Tres Popp
llvmlistbot at llvm.org
Mon Feb 15 07:05:48 PST 2021
Author: Tres Popp
Date: 2021-02-15T16:05:32+01:00
New Revision: 3842d4b6791f6fbd67a1d12806f05a05654728cf
URL: https://github.com/llvm/llvm-project/commit/3842d4b6791f6fbd67a1d12806f05a05654728cf
DIFF: https://github.com/llvm/llvm-project/commit/3842d4b6791f6fbd67a1d12806f05a05654728cf.diff
LOG: Make shape.is_broadcastable/shape.cstr_broadcastable nary
This corresponds with the previous work to make shape.broadcast nary.
Additionally, simplify the ConvertShapeConstraints pass: it no longer
lowers the implicit shape.is_broadcastable check itself, leaving that op
for shape-to-standard to lower. The end result is unchanged when the two
passes are combined, in either order.
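For illustration, both ops now accept any number of shape or extent-tensor
operands rather than exactly two. A minimal sketch of the new assembly,
mirroring the updated tests below (value names are hypothetical):

```mlir
%0 = shape.is_broadcastable %a, %b, %c
    : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
%w = shape.cstr_broadcastable %a, %b, %c
    : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
```

The old binary form still parses and can still be built through the new
two-operand builder.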
Differential Revision: https://reviews.llvm.org/D96401
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 20b0706be367..b50a6f99e04c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -190,11 +190,12 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
let assemblyFormat = "$input attr-dict `:` type($input)";
}
-def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
- let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
+ [Commutative, InferTypeOpInterface]> {
+ let summary = "Determines if 2+ shapes can be successfully broadcasted";
let description = [{
- Given two input shapes or extent tensors, return a predicate specifying if
- they are broadcastable. This broadcastable follows the same logic as what
+ Given multiple input shapes or extent tensors, return a predicate specifying
+    whether they are broadcastable. Broadcastability follows the same logic as what
shape.broadcast documents.
Concretely, shape.is_broadcastable returning true implies that
@@ -209,11 +210,28 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
```
}];
- let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
- Shape_ShapeOrExtentTensorType:$rhs);
+ let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs I1:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+ let builders = [
+ OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+ [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+ ];
+ let extraClassDeclaration = [{
+ // TODO: This should really be automatic. Figure out how to not need this defined.
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+ /*width=*/1));
+ return success();
+ };
+ }];
+
+ let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+ let verifier = [{ return ::verify(*this); }];
+
}
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
@@ -692,11 +710,12 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
- let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
+ [Commutative, InferTypeOpInterface]> {
+ let summary = "Determines if 2+ shapes can be successfully broadcasted";
let description = [{
- Given two input shapes or extent tensors, return a witness specifying if
- they are broadcastable. This broadcastable follows the same logic as what
+    Given input shapes or extent tensors, return a witness specifying whether
+    they are broadcastable. Broadcastability follows the same logic as what
shape.broadcast documents.
"cstr" operations represent runtime assertions.
@@ -708,14 +727,30 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
```
}];
- let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
- Shape_ShapeOrExtentTensorType:$rhs);
+ let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
let results = (outs Shape_WitnessType:$result);
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+ let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+ let builders = [
+ OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+ [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ // TODO: This should really be automatic. Figure out how to not need this defined.
+ static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+ ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+ ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+ return success();
+ };
+ }];
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let verifier = [{ return ::verify(*this); }];
}
def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
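Each variadic operand only needs to satisfy Shape_ShapeOrExtentTensorType on
its own, so shapes and extent tensors can presumably be mixed within a single
op. A hypothetical sketch (operand names illustrative):

```mlir
%w = shape.cstr_broadcastable %s, %t, %e
    : !shape.shape, !shape.shape, tensor<?xindex>
```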
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index 65b1fa1096d6..e9d31ac93438 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -19,77 +19,8 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
-
namespace {
-class ConvertCstrBroadcastableOp
- : public OpRewritePattern<shape::CstrBroadcastableOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
- PatternRewriter &rewriter) const override {
- if (op.getType().isa<shape::ShapeType>() ||
- op.lhs().getType().isa<shape::ShapeType>() ||
- op.rhs().getType().isa<shape::ShapeType>()) {
- return rewriter.notifyMatchFailure(
- op, "cannot convert error-propagating shapes");
- }
-
- auto loc = op.getLoc();
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
- // Find smaller and greater rank and extent tensor.
- Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
- Value lhsRankULE =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
- Type indexTy = rewriter.getIndexType();
- Value lesserRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
- Value greaterRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
- Value lesserRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
- Value greaterRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
-
- Value rankDiff =
- rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-
- // Generate code to compare the shapes extent by extent, and emit errors for
- // non-broadcast-compatible shapes.
- // Two extents are broadcast-compatible if
- // 1. they are both equal, or
- // 2. at least one of them is 1.
-
- rewriter.create<scf::ForOp>(
- loc, rankDiff, greaterRank, one, llvm::None,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
- Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, greaterRankOperand, ValueRange{iv});
- Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, lesserRankOperand, ValueRange{ivShifted});
-
- Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
- loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
- Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
- loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
- Value extentsAgree =
- b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
- lesserRankOperandExtent);
- auto broadcastIsValid =
- b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
- b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
- lesserRankOperandExtentIsOne));
- b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
- b.create<scf::YieldOp>(loc);
- });
-
- rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
- return success();
- }
-};
+#include "ShapeToStandard.cpp.inc"
} // namespace
namespace {
@@ -107,7 +38,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
void mlir::populateConvertShapeConstraintsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+ patterns.insert<CstrBroadcastableToRequire>(ctx);
patterns.insert<ConvertCstrRequireOp>(ctx);
}
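The pass itself no longer open-codes the extent-by-extent check: the DRR
pattern rewrites the constraint into shape.is_broadcastable feeding
shape.cstr_require, and ConvertCstrRequireOp lowers the require into an
assert. Schematically, matching the updated test below:

```mlir
// Before:
%w = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
// After convert-shape-constraints:
%0 = shape.is_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
assert %0, "required broadcastable shapes"
%w = shape.const_witness true
```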
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 3c83b4371df3..5f4396d73d88 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -237,63 +237,84 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
// For now, this lowering is only defined on `tensor<?xindex>` operands, not
// on shapes.
IsBroadcastableOp::Adaptor transformed(operands);
- if (transformed.lhs().getType().isa<ShapeType>() ||
- transformed.rhs().getType().isa<ShapeType>())
+  if (llvm::any_of(op.shapes(),
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
return failure();
auto loc = op.getLoc();
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ ImplicitLocOpBuilder lb(loc, rewriter);
+ Value zero = lb.create<ConstantIndexOp>(0);
+ Value one = lb.create<ConstantIndexOp>(1);
+ Type indexTy = lb.getIndexType();
+
+ // Save all the ranks for bounds checking. Because this is a tensor
+ // representing the shape extents, the rank is the extent of the only
+ // dimension in the tensor.
+ SmallVector<Value> ranks, rankDiffs;
+ llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+ return lb.create<DimOp>(v, zero);
+ }));
+
+  // Find the maximum rank.
+ Value maxRank = ranks.front();
+ for (Value v : llvm::drop_begin(ranks, 1)) {
+ Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+ maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+ }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+ llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+ return lb.create<SubIOp>(indexTy, maxRank, v);
+ }));
- // Find smaller and greater rank and extent tensor.
- Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
- Value lhsRankULE =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
- Type indexTy = rewriter.getIndexType();
- Value lesserRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
- Value greaterRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
- auto erasedRankType =
- RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
- Value rankErasedLhs =
- rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
- Value rankErasedRhs =
- rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
- Value lesserRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
- Value greaterRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
- Value rankDiff =
- rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
Type i1Ty = rewriter.getI1Type();
- Value init =
+ Value trueVal =
rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
- // Determine if all overlapping extents are broadcastable.
- auto reduceResult = rewriter.create<ForOp>(
- loc, rankDiff, greaterRank, one, ValueRange{init},
+ auto reduceResult = lb.create<ForOp>(
+ loc, zero, maxRank, one, ValueRange{trueVal},
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
- Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, greaterRankOperand, ValueRange{iv});
- Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
- loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
- Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, lesserRankOperand, ValueRange{ivShifted});
- Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
- loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
- Value extentsAreEqual =
- b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
- lesserRankOperandExtent);
- Value broadcastableExtents = b.create<AndOp>(
- loc, iterArgs[0],
- b.create<OrOp>(loc,
- b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
- lesserRankOperandExtentIsOne),
- extentsAreEqual));
- b.create<scf::YieldOp>(loc, broadcastableExtents);
+ // Find a non-1 dim, if it exists. Note that the first part of this
+ // could reuse the Broadcast lowering entirely, but we redo the work
+ // here to make optimizations easier between the two loops.
+ Value broadcastedDim = getBroadcastedDim(
+ ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
+
+ Value broadcastable = iterArgs[0];
+ for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
+ Value shape, rankDiff;
+ std::tie(shape, rankDiff) = tup;
+ Value outOfBounds =
+ b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
+ broadcastable =
+ b.create<IfOp>(
+ loc, TypeRange{i1Ty}, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+                    // Non-existent dimensions are always broadcastable.
+ b.create<scf::YieldOp>(loc, broadcastable);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // Every value needs to be either 1, or the same non-1
+ // value to be broadcastable in this dim.
+ Value operandDimension =
+ b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+ Value dimensionExtent = b.create<tensor::ExtractOp>(
+ loc, shape, ValueRange{operandDimension});
+
+ Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+ dimensionExtent, one);
+ Value equalBroadcasted =
+ b.create<CmpIOp>(loc, CmpIPredicate::eq,
+ dimensionExtent, broadcastedDim);
+ Value result = b.create<AndOp>(
+ loc, broadcastable,
+ b.create<OrOp>(loc, equalOne, equalBroadcasted));
+ b.create<scf::YieldOp>(loc, result);
+ })
+ .getResult(0);
+ }
+
+ b.create<scf::YieldOp>(loc, broadcastable);
});
rewriter.replaceOp(op, reduceResult.results().front());
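The loop now runs over every position up to the maximum rank:
getBroadcastedDim (a helper presumably shared with the nary shape.broadcast
lowering in this file) computes the broadcasted extent at each position, and
every in-bounds operand extent must equal 1 or that broadcasted extent. As a
worked example, right-aligning [2, 3, 1] and [3, 4]: position 0 pairs 2 with
an out-of-bounds slot (compatible), position 1 pairs 3 with 3, and position 2
pairs 1 with 4, so the result is true. A minimal input exercising this path
(extent tensors only, since ShapeType operands make the pattern bail out):

```mlir
func @is_broadcastable(%a : tensor<3xindex>, %b : tensor<2xindex>) -> i1 {
  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<2xindex>
  return %0 : i1
}
```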
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
index a5eaa7a2a889..aac3789c3b58 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -19,9 +19,9 @@ def BroadcastableStringAttr : NativeCodeCall<[{
$_builder.getStringAttr("required broadcastable shapes")
}]>;
-def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
+def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
(Shape_CstrRequireOp
- (Shape_IsBroadcastableOp $LHS, $RHS),
+ (Shape_IsBroadcastableOp $shapes),
(BroadcastableStringAttr))>;
#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8c75bdc9aa16..058c0c58dda2 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -491,6 +491,10 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
}
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+ // TODO: Add folding for the nary case
+ if (operands.size() != 2)
+ return nullptr;
+
// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
@@ -512,9 +516,9 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// Lastly, see if folding can be completed based on what constraints are known
// on the input shapes.
SmallVector<int64_t, 6> lhsShape, rhsShape;
- if (failed(getShapeVec(lhs(), lhsShape)))
+ if (failed(getShapeVec(shapes()[0], lhsShape)))
return nullptr;
- if (failed(getShapeVec(rhs(), rhsShape)))
+ if (failed(getShapeVec(shapes()[1], rhsShape)))
return nullptr;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
@@ -525,6 +529,13 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
+static LogicalResult verify(CstrBroadcastableOp op) {
+  // Ensure that the op has at least two operands.
+ if (op.getNumOperands() < 2)
+    return op.emitOpError("requires at least 2 input shapes");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//
@@ -723,6 +734,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
}
}
+//===----------------------------------------------------------------------===//
+// IsBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(IsBroadcastableOp op) {
+  // Ensure that the op has at least two operands.
+ if (op.getNumOperands() < 2)
+    return op.emitOpError("requires at least 2 input shapes");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
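Folding remains binary-only for now, per the TODO. A hedged sketch of its
effect, assuming constant shapes that staticallyKnownBroadcastable accepts:

```mlir
%a = shape.const_shape [8, 1] : !shape.shape
%b = shape.const_shape [1, 8] : !shape.shape
// Two statically broadcastable operands: folds to a constant true witness.
%w = shape.cstr_broadcastable %a, %b : !shape.shape, !shape.shape
// => %w = shape.const_witness true
```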
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 45c699baece3..8f847b1b28c5 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -18,8 +18,9 @@ def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
(replaceWithValue $args),
[(HasSingleElement $args)]>;
-def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
- (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $shapes),
+ (Shape_ConstWitnessOp ConstBoolAttrTrue),
+ [(AllInputShapesEq $shapes)]>;
def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
(Shape_ConstWitnessOp ConstBoolAttrTrue),
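The rewritten pattern matches any arity, guarded by the AllInputShapesEq
constraint; a sketch of what it now folds away:

```mlir
// All operands are the same value, so the constraint trivially holds:
%w = shape.cstr_broadcastable %a, %a, %a
    : !shape.shape, !shape.shape, !shape.shape
// => %w = shape.const_witness true
```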
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 688b9fbffba7..5b47d9453261 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -4,28 +4,9 @@
// CHECK-LABEL: func @cstr_broadcastable(
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[RET:.*]] = shape.const_witness true
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
-// CHECK: %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
-// CHECK: %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
-// CHECK: assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
-// CHECK: }
+// CHECK: %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
+// CHECK: assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 329e86848aa9..385e296177ad 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,77 +305,184 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
// -----
-func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
- %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
+func @try_is_broadcastable(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> i1 {
+ %0 = shape.is_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
return %0 : i1
}
-
-// CHECK-LABEL: func @try_is_broadcastable(
-// CHECK-SAME: %[[LHS:.*]]: tensor<3xindex>,
-// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK-LABEL: @try_is_broadcastable
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK: %[[TRUE:.*]] = constant true
-// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
-// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
-// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK: }
-// CHECK: return %[[ALL_RESULT]] : i1
-// CHECK: }
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[C1_0:.*]] = constant 1 : index
+// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK: scf.yield %[[C1_0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK: scf.yield %[[DIM0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK: scf.yield %[[DIM1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK: }
+// CHECK: %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK: scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK: scf.yield %[[REDUCTION_0]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK: scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED:.*]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED:.*]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1:.*]], %[[EQUALS_BROADCASTED:.*]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: scf.yield %[[FINAL_RESULT]] : i1
// -----
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
- %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> !shape.witness {
+ %0 = shape.cstr_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
return %0 : !shape.witness
}
-
// CHECK-LABEL: func @broadcast(
-// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>)
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
// CHECK: %[[TRUE:.*]] = constant true
-// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK: %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
-// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
-// CHECK: %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK: }
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[C1_0:.*]] = constant 1 : index
+// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK: scf.yield %[[C1_0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK: scf.yield %[[DIM0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK: scf.yield %[[DIM1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK: }
+// CHECK: %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK: scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK: scf.yield %[[REDUCTION_0]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK: scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK: } else {
+// CHECK: %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK: %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED:.*]], %c1 : index
+// CHECK: %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED:.*]], %[[DIM2]] : index
+// CHECK: %[[GOOD:.*]] = or %[[EQUALS_1:.*]], %[[EQUALS_BROADCASTED:.*]] : i1
+// CHECK: %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK: scf.yield %[[AND_REDUCTION]] : i1
+// CHECK: }
+// CHECK: scf.yield %[[FINAL_RESULT]] : i1
+
// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK: return %[[RESULT]] : !shape.witness
// CHECK: }
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d2f5af2f7b30..d685e6766072 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -246,3 +246,21 @@ func @fn(%arg: !shape.value_shape) -> !shape.shape {
// expected-error @+1 {{@fn not found}}
module attributes {shape.lib = @fn} { }
+
+// -----
+
+func @fn(%arg: !shape.shape) -> i1 {
+  // expected-error @+1 {{requires at least 2 input shapes}}
+ %0 = shape.is_broadcastable %arg : !shape.shape
+ return %0 : i1
+}
+
+// -----
+
+func @fn(%arg: !shape.shape) -> !shape.witness {
+  // expected-error @+1 {{requires at least 2 input shapes}}
+ %0 = shape.cstr_broadcastable %arg : !shape.shape
+ return %0 : !shape.witness
+}
+
+