[Mlir-commits] [mlir] 3842d4b - Make shape.is_broadcastable/shape.cstr_broadcastable nary

Tres Popp llvmlistbot at llvm.org
Mon Feb 15 07:05:48 PST 2021


Author: Tres Popp
Date: 2021-02-15T16:05:32+01:00
New Revision: 3842d4b6791f6fbd67a1d12806f05a05654728cf

URL: https://github.com/llvm/llvm-project/commit/3842d4b6791f6fbd67a1d12806f05a05654728cf
DIFF: https://github.com/llvm/llvm-project/commit/3842d4b6791f6fbd67a1d12806f05a05654728cf.diff

LOG: Make shape.is_broadcastable/shape.cstr_broadcastable nary

This corresponds with the previous work to make shape.broadcast nary.
Additionally, simplify the ConvertShapeConstraints pass. It now doesn't
lower an implicit shape.is_broadcastable. This is still the same in
combination with shape-to-standard when the 2 passes are used in either
order.

Differential Revision: https://reviews.llvm.org/D96401

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
    mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
    mlir/test/Dialect/Shape/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 20b0706be367..b50a6f99e04c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -190,11 +190,12 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }
 
-def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
-  let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
+                                       [Commutative, InferTypeOpInterface]> {
+  let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
-    Given two input shapes or extent tensors, return a predicate specifying if
-    they are broadcastable. This broadcastable follows the same logic as what
+    Given multiple input shapes or extent tensors, return a predicate specifying
+    if they are broadcastable. This broadcastable follows the same logic as what
     shape.broadcast documents.
 
     Concretely, shape.is_broadcastable returning true implies that
@@ -209,11 +210,28 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs I1:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let builders = [
+  OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+    [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::IntegerType::get(context,
+                                                             /*width=*/1));
+      return success();
+    };
+  }];
+
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+  let verifier = [{ return ::verify(*this); }];
+
 }
 
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
@@ -692,11 +710,12 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
-def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
-  let summary = "Determines if 2 shapes can be successfully broadcasted";
+def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
+                                         [Commutative, InferTypeOpInterface]> {
+  let summary = "Determines if 2+ shapes can be successfully broadcasted";
   let description = [{
-    Given two input shapes or extent tensors, return a witness specifying if
-    they are broadcastable. This broadcastable follows the same logic as what
+    Given input shapes or extent tensors, return a witness specifying if they
+    are broadcastable. This broadcastable follows the same logic as what
     shape.broadcast documents.
 
     "cstr" operations represent runtime assertions.
@@ -708,14 +727,30 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs);
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
+
+  let builders = [
+  OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs),
+    [{ build($_builder, $_state, ::llvm::makeArrayRef({lhs, rhs})); }]>,
+  ];
+
+  let extraClassDeclaration = [{
+    // TODO: This should really be automatic. Figure out how to not need this defined.
+    static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
+    ::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
+    ::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
+    ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
+      inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
+      return success();
+    };
+  }];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {

diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index 65b1fa1096d6..e9d31ac93438 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -19,77 +19,8 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
-
 namespace {
-class ConvertCstrBroadcastableOp
-    : public OpRewritePattern<shape::CstrBroadcastableOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getType().isa<shape::ShapeType>() ||
-        op.lhs().getType().isa<shape::ShapeType>() ||
-        op.rhs().getType().isa<shape::ShapeType>()) {
-      return rewriter.notifyMatchFailure(
-          op, "cannot convert error-propagating shapes");
-    }
-
-    auto loc = op.getLoc();
-    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
-    // Find smaller and greater rank and extent tensor.
-    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-    Value lhsRankULE =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-    Type indexTy = rewriter.getIndexType();
-    Value lesserRank =
-        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-    Value greaterRank =
-        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-    Value lesserRankOperand =
-        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
-    Value greaterRankOperand =
-        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
-
-    Value rankDiff =
-        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-
-    // Generate code to compare the shapes extent by extent, and emit errors for
-    // non-broadcast-compatible shapes.
-    // Two extents are broadcast-compatible if
-    // 1. they are both equal, or
-    // 2. at least one of them is 1.
-
-    rewriter.create<scf::ForOp>(
-        loc, rankDiff, greaterRank, one, llvm::None,
-        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-          Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-              loc, greaterRankOperand, ValueRange{iv});
-          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-          Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-              loc, lesserRankOperand, ValueRange{ivShifted});
-
-          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
-              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
-          Value extentsAgree =
-              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
-                               lesserRankOperandExtent);
-          auto broadcastIsValid =
-              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
-                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
-                                            lesserRankOperandExtentIsOne));
-          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
-          b.create<scf::YieldOp>(loc);
-        });
-
-    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
-    return success();
-  }
-};
+#include "ShapeToStandard.cpp.inc"
 } // namespace
 
 namespace {
@@ -107,7 +38,7 @@ class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
 
 void mlir::populateConvertShapeConstraintsConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+  patterns.insert<CstrBroadcastableToRequire>(ctx);
   patterns.insert<ConvertCstrRequireOp>(ctx);
 }
 

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 3c83b4371df3..5f4396d73d88 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -237,63 +237,84 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
   IsBroadcastableOp::Adaptor transformed(operands);
-  if (transformed.lhs().getType().isa<ShapeType>() ||
-      transformed.rhs().getType().isa<ShapeType>())
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
 
   auto loc = op.getLoc();
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  ImplicitLocOpBuilder lb(loc, rewriter);
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+
+  // Find the maximum rank
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
 
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
   Type i1Ty = rewriter.getI1Type();
-  Value init =
+  Value trueVal =
       rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
 
-  // Determine if all overlapping extents are broadcastable.
-  auto reduceResult = rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, ValueRange{init},
+  auto reduceResult = lb.create<ForOp>(
+      loc, zero, maxRank, one, ValueRange{trueVal},
       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, ValueRange{iv});
-        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-        Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, lesserRankOperand, ValueRange{ivShifted});
-        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
-        Value extentsAreEqual =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
-                             lesserRankOperandExtent);
-        Value broadcastableExtents = b.create<AndOp>(
-            loc, iterArgs[0],
-            b.create<OrOp>(loc,
-                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
-                                          lesserRankOperandExtentIsOne),
-                           extentsAreEqual));
-        b.create<scf::YieldOp>(loc, broadcastableExtents);
+        // Find a non-1 dim, if it exists. Note that the first part of this
+        // could reuse the Broadcast lowering entirely, but we redo the work
+        // here to make optimizations easier between the two loops.
+        Value broadcastedDim = getBroadcastedDim(
+            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
+
+        Value broadcastable = iterArgs[0];
+        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
+          Value shape, rankDiff;
+          std::tie(shape, rankDiff) = tup;
+          Value outOfBounds =
+              b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
+          broadcastable =
+              b.create<IfOp>(
+                   loc, TypeRange{i1Ty}, outOfBounds,
+                   [&](OpBuilder &b, Location loc) {
+                     // Non existent dimensions are always broadcastable
+                     b.create<scf::YieldOp>(loc, broadcastable);
+                   },
+                   [&](OpBuilder &b, Location loc) {
+                     // Every value needs to be either 1, or the same non-1
+                     // value to be broadcastable in this dim.
+                     Value operandDimension =
+                         b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+                     Value dimensionExtent = b.create<tensor::ExtractOp>(
+                         loc, shape, ValueRange{operandDimension});
+
+                     Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                       dimensionExtent, one);
+                     Value equalBroadcasted =
+                         b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                          dimensionExtent, broadcastedDim);
+                     Value result = b.create<AndOp>(
+                         loc, broadcastable,
+                         b.create<OrOp>(loc, equalOne, equalBroadcasted));
+                     b.create<scf::YieldOp>(loc, result);
+                   })
+                  .getResult(0);
+        }
+
+        b.create<scf::YieldOp>(loc, broadcastable);
       });
 
   rewriter.replaceOp(op, reduceResult.results().front());

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
index a5eaa7a2a889..aac3789c3b58 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.td
@@ -19,9 +19,9 @@ def BroadcastableStringAttr : NativeCodeCall<[{
   $_builder.getStringAttr("required broadcastable shapes")
 }]>;
 
-def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
+def CstrBroadcastableToRequire : Pat<(Shape_CstrBroadcastableOp $shapes),
             (Shape_CstrRequireOp
-              (Shape_IsBroadcastableOp $LHS, $RHS),
+              (Shape_IsBroadcastableOp $shapes),
               (BroadcastableStringAttr))>;
 
 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 8c75bdc9aa16..058c0c58dda2 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -491,6 +491,10 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
 }
 
 OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // TODO: Add folding for the nary case
+  if (operands.size() != 2)
+    return nullptr;
+
   // Both operands are not needed if one is a scalar.
   if (operands[0] &&
       operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
@@ -512,9 +516,9 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
   // Lastly, see if folding can be completed based on what constraints are known
   // on the input shapes.
   SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(lhs(), lhsShape)))
+  if (failed(getShapeVec(shapes()[0], lhsShape)))
     return nullptr;
-  if (failed(getShapeVec(rhs(), rhsShape)))
+  if (failed(getShapeVec(shapes()[1], rhsShape)))
     return nullptr;
 
   if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
@@ -525,6 +529,13 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+static LogicalResult verify(CstrBroadcastableOp op) {
+  // Ensure that AssumingAllOp contains at least one operand
+  if (op.getNumOperands() < 2)
+    return op.emitOpError("required at least 2 input shapes");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CstrEqOp
 //===----------------------------------------------------------------------===//
@@ -723,6 +734,17 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// IsBroadcastableOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(IsBroadcastableOp op) {
+  // Ensure that AssumingAllOp contains at least one operand
+  if (op.getNumOperands() < 2)
+    return op.emitOpError("required at least 2 input shapes");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index 45c699baece3..8f847b1b28c5 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -18,8 +18,9 @@ def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
                            (replaceWithValue $args),
                            [(HasSingleElement $args)]>;
 
-def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
-  (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
+def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $shapes),
+  (Shape_ConstWitnessOp ConstBoolAttrTrue),
+  [(AllInputShapesEq $shapes)]>;
 
 def CstrEqEqOps : Pat<(Shape_CstrEqOp:$op $shapes),
   (Shape_ConstWitnessOp ConstBoolAttrTrue),

diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 688b9fbffba7..5b47d9453261 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -4,28 +4,9 @@
 // CHECK-LABEL:   func @cstr_broadcastable(
 // CHECK-SAME:                             %[[LHS:.*]]: tensor<?xindex>,
 // CHECK-SAME:                             %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
-// CHECK:           %[[C0:.*]] = constant 0 : index
-// CHECK:           %[[C1:.*]] = constant 1 : index
 // CHECK:           %[[RET:.*]] = shape.const_witness true
-// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
-// CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
-// CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-// CHECK:           scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-// CHECK:             %[[EXTENTS_AGREE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
-// CHECK:             %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
-// CHECK:             %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
-// CHECK:             assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
-// CHECK:           }
+// CHECK:           %[[BROADCAST_IS_VALID:.*]] = shape.is_broadcastable %[[LHS]], %[[RHS]]
+// CHECK:           assert %[[BROADCAST_IS_VALID]], "required broadcastable shapes"
 // CHECK:           return %[[RET]] : !shape.witness
 // CHECK:         }
 func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 329e86848aa9..385e296177ad 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,77 +305,184 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
 
 // -----
 
-func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
-  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
+func @try_is_broadcastable (%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> i1 {
+  %0 = shape.is_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
   return %0 : i1
 }
-
-// CHECK-LABEL:   func @try_is_broadcastable(
-// CHECK-SAME:        %[[LHS:.*]]: tensor<3xindex>,
-// CHECK-SAME:        %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK-LABEL: @try_is_broadcastable
+// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>)
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
-// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
-// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
-// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
 // CHECK:           %[[TRUE:.*]] = constant true
-// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK:             %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
-// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK:             %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
-// CHECK:             %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK:           }
-// CHECK:           return %[[ALL_RESULT]] : i1
-// CHECK:         }
+// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK:             %[[C1_0:.*]] = constant 1 : index
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1_0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:               %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK:               scf.yield %[[DIM0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:               %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK:               scf.yield %[[DIM1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:               %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK:                scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK:                scf.yield %[[REDUCTION_0]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK:                scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED:.*]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED:.*]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1:.*]], %[[EQUALS_BROADCASTED:.*]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             scf.yield %[[FINAL_RESULT]] : i1
 
 // -----
 
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
-  %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+func @broadcast(%a : tensor<2xindex>, %b : tensor<3xindex>, %c : tensor<2xindex>) -> !shape.witness {
+  %0 = shape.cstr_broadcastable %a, %b, %c : tensor<2xindex>, tensor<3xindex>, tensor<2xindex>
   return %0 : !shape.witness
 }
-
 // CHECK-LABEL:   func @broadcast(
-// CHECK-SAME:                    %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME:                    %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>)
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
-// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[LHS_SMALLER:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
-// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
-// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
 // CHECK:           %[[TRUE:.*]] = constant true
-// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
-// CHECK:             %[[LARGER_EXTENT:.*]] = tensor.extract %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
-// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[C1]] : index
-// CHECK:             %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
-// CHECK:             %[[SMALLER_EXTENT:.*]] = tensor.extract %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
-// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi eq, %[[SMALLER_EXTENT]], %[[C1]] : index
-// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi eq, %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
-// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
-// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
-// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
-// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
-// CHECK:           }
+// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[IDX:.*]] = %[[C0]] to %[[MAX_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK:             %[[C1_0:.*]] = constant 1 : index
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1_0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:               %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1_0]], %[[EXTRACTED_0]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK:               scf.yield %[[DIM0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:               %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK:               scf.yield %[[DIM1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:               %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2]], %[[C1_0]] : index
+// CHECK:               %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[REDUCTION_0:.*]] = scf.if %[[OUT_BOUND_0]] -> (i1) {
+// CHECK:                scf.yield %[[ALL_SO_FAR]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg0[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[ALL_SO_FAR]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_1:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[SECOND_REDUCTION:.*]] = scf.if %[[OUT_BOUND_1]] -> (i1) {
+// CHECK:                scf.yield %[[REDUCTION_0]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg1[%[[SHIFTED]]] : tensor<3xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1]], %[[EQUALS_BROADCASTED]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[REDUCTION_0]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             %[[OUT_BOUND_2:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[FINAL_RESULT:.*]] = scf.if %[[OUT_BOUND_2]] -> (i1) {
+// CHECK:                scf.yield %[[SECOND_REDUCTION]] : i1
+// CHECK:             } else {
+// CHECK:                %[[SHIFTED:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:                %[[EXTRACTED:.*]] = tensor.extract %arg2[%[[SHIFTED]]] : tensor<2xindex>
+// CHECK:                %[[EQUALS_1:.*]] = cmpi eq, %[[EXTRACTED:.*]], %c1 : index
+// CHECK:                %[[EQUALS_BROADCASTED:.*]] = cmpi eq, %[[EXTRACTED:.*]], %[[DIM2]] : index
+// CHECK:                %[[GOOD:.*]] = or %[[EQUALS_1:.*]], %[[EQUALS_BROADCASTED:.*]] : i1
+// CHECK:                %[[AND_REDUCTION:.*]] = and %[[SECOND_REDUCTION]], %[[GOOD]] : i1
+// CHECK:                scf.yield %[[AND_REDUCTION]] : i1
+// CHECK:             }
+// CHECK:             scf.yield %[[FINAL_RESULT]] : i1
+
 // CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK:           return %[[RESULT]] : !shape.witness
 // CHECK:         }

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index d2f5af2f7b30..d685e6766072 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -246,3 +246,21 @@ func @fn(%arg: !shape.value_shape) -> !shape.shape {
 
 // expected-error @+1 {{@fn not found}}
 module attributes {shape.lib = @fn} { }
+
+// -----
+
+func @fn(%arg: !shape.shape) -> i1 {
+  // expected-error @+1 {{required at least 2 input shapes}}
+  %0 = shape.is_broadcastable %arg : !shape.shape
+  return %0 : i1
+}
+
+// -----
+
+func @fn(%arg: !shape.shape) -> !shape.witness {
+  // expected-error @+1 {{required at least 2 input shapes}}
+  %0 = shape.cstr_broadcastable %arg : !shape.shape
+  return %0 : !shape.witness
+}
+
+


        


More information about the Mlir-commits mailing list