[Mlir-commits] [mlir] f30f347 - [mlir][shape] Generalize broadcast to a variadic number of shapes
Tres Popp
llvmlistbot at llvm.org
Tue Feb 9 23:31:50 PST 2021
Author: Tres Popp
Date: 2021-02-10T08:31:28+01:00
New Revision: f30f347da1f8b9da231368f37538a8de49768d49
URL: https://github.com/llvm/llvm-project/commit/f30f347da1f8b9da231368f37538a8de49768d49
DIFF: https://github.com/llvm/llvm-project/commit/f30f347da1f8b9da231368f37538a8de49768d49.diff
LOG: [mlir][shape] Generalize broadcast to a variadic number of shapes
Previously broadcast was a binary op. Now it can support more inputs.
This has been changed in such a way that for now, this is an NFC for
all broadcast operations that were previously legal.
Differential Revision: https://reviews.llvm.org/D95777
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ba89a9455781..271a4f87eec9 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -50,12 +50,13 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
}
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
- let summary = "Returns the broadcasted output shape of two inputs";
+ let summary = "Returns the broadcasted output shape of two or more inputs";
let description = [{
- Returns the broadcasted shape for two input shapes or extent tensors. Both
- operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
- type `shape.shape` and, if both operands are tensors, may be of type
- `tensor<?xindex>`.
+ Returns the broadcasted shape for input shapes or extent tensors. The rest
+ of this description is simplified for the 2 input case but can be extended
+ to more inputs. Both operands can be of type `shape.shape` or
+ `tensor<?xindex>`. The result is of type `shape.shape` and, if both
+ operands are tensors, may be of type `tensor<?xindex>`.
If the two operand shapes are of different rank the smaller one is padded
with 1's from the left. The resulting broadcasted shape is then defined as
@@ -72,19 +73,26 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
attribute can be used to describe the error case.
}];
- let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
- Shape_ShapeOrExtentTensorType:$rhs,
+ let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrExtentTensorType:$result);
let assemblyFormat = [{
- $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+ $shapes attr-dict `:` type($shapes) `->` type($result)
}];
- let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
- let hasFolder = 1;
+ let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
+ "::mlir::Value":$lhs, "::mlir::Value":$rhs,
+ "/*optional*/ ::mlir::StringAttr":$error), [{
+ build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
+ }]>
+ ];
- let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+ let hasFolder = 1;
+ let verifier = [{
+ return success(succeeded(::verifyShapeOrExtentTensorOp(*this)) &&
+ getNumOperands() >= 2);
+ }];
}
def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0eeea250f19f..3c83b4371df3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,9 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::shape;
@@ -73,6 +75,48 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
+
+// Get the resulting extent in a given dimension. This is computed with any
+// number of extent tensors and shifted offsets into them.
+Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
+ ValueRange rankDiffs, Value outputDimension) {
+ Value one = lb.create<ConstantIndexOp>(1);
+ Value broadcastedDim = one;
+ for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
+ Value shape = std::get<0>(tup);
+ Value rankDiff = std::get<1>(tup);
+ Value outOfBounds =
+ lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
+ Type indexTy = lb.getIndexType();
+ broadcastedDim =
+ lb.create<IfOp>(
+ TypeRange{indexTy}, outOfBounds,
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, broadcastedDim);
+ },
+ [&](OpBuilder &b, Location loc) {
+ // The broadcasting logic is:
+ // - if one extent (here we arbitrarily choose the
+ // extent from the greater-rank operand) is equal to 1,
+ // then take the extent from the other operand
+ // - otherwise, take the extent as-is.
+ // Note that this logic remains correct in the presence
+ // of dimensions of zero extent.
+ Value lesserRankOperandDimension =
+ b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+ Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
+ loc, shape, ValueRange{lesserRankOperandDimension});
+
+ Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+ lesserRankOperandExtent, one);
+ Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
+ lesserRankOperandExtent);
+ b.create<scf::YieldOp>(loc, dim);
+ })
+ .getResult(0);
+ }
+ return broadcastedDim;
+}
} // namespace
LogicalResult BroadcastOpConverter::matchAndRewrite(
@@ -83,76 +127,44 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
if (op.getType().isa<ShapeType>())
return failure();
- assert(!op.lhs().getType().isa<ShapeType>() &&
- !op.rhs().getType().isa<ShapeType>());
auto loc = op.getLoc();
+ ImplicitLocOpBuilder lb(loc, rewriter);
BroadcastOp::Adaptor transformed(operands);
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- // Find smaller and greater rank and extent tensor.
- Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
- Value lhsRankULE =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
- Type indexTy = rewriter.getIndexType();
- Value lesserRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
- Value greaterRank =
- rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
- auto erasedRankType =
- RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
- Value rankErasedLhs =
- rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
- Value rankErasedRhs =
- rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
- Value lesserRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
- Value greaterRankOperand =
- rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+ Value zero = lb.create<ConstantIndexOp>(0);
+ Type indexTy = lb.getIndexType();
+
+ // Save all the ranks for bounds checking. Because this is a tensor
+ // representing the shape extents, the rank is the extent of the only
+ // dimension in the tensor.
+ SmallVector<Value> ranks, rankDiffs;
+ llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+ return lb.create<DimOp>(v, zero);
+ }));
+
+ // Find the maximum rank
+ Value maxRank = ranks.front();
+ for (Value v : llvm::drop_begin(ranks, 1)) {
+ Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+ maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+ }
- Value rankDiff =
- rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
- rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
- op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value outputDimension = args[0];
- Value isUnchallengedDimension = b.create<CmpIOp>(
- loc, CmpIPredicate::ult, outputDimension, rankDiff);
- Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, greaterRankOperand, outputDimension);
- // The initial dimensions of the greater-rank operand are unchallenged,
- // so we can take them as-is. Otherwise, we need to do a comparison.
- // We need an actual branch here (instead of a select) because the
- // lesser-rank operand might be rank 0, so any tensor.extract would be
- // invalid.
- auto ifOp = b.create<IfOp>(
- loc, TypeRange{indexTy}, isUnchallengedDimension,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
- },
- [&](OpBuilder &b, Location loc) {
- // The broadcasting logic is:
- // - if one extent (here we arbitrarily choose the extent from
- // the greater-rank operand) is equal to 1, then take the extent
- // from the other operand
- // - otherwise, take the extent as-is.
- // Note that this logic remains correct in the presence of
- // dimensions of zero extent.
- Value lesserRankOperandDimension =
- b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
- Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
- loc, lesserRankOperand,
- ValueRange{lesserRankOperandDimension});
- Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
- loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
- Value broadcastedExtent = b.create<SelectOp>(
- loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
- greaterRankOperandExtent);
- b.create<scf::YieldOp>(loc, broadcastedExtent);
- });
- b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
- });
+ // Calculate the difference of ranks and the maximum rank for later offsets.
+ llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+ return lb.create<SubIOp>(indexTy, maxRank, v);
+ }));
+
+ rewriter.replaceOp(
+ op, lb.create<tensor::GenerateOp>(
+ getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value broadcastedDim = getBroadcastedDim(
+ ImplicitLocOpBuilder(loc, b), transformed.shapes(),
+ rankDiffs, args[0]);
+
+ b.create<tensor::YieldOp>(loc, broadcastedDim);
+ })
+ ->getResults());
return success();
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65ebc54aeeb3..9657f9566ea6 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -357,10 +357,14 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (!operands[1])
return nullptr;
+ // TODO: Support folding with more than 2 input shapes
+ if (operands.size() > 2 && !operands[2].isa<StringAttr>())
+ return nullptr;
+
auto rhsShape = llvm::to_vector<6>(
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
if (rhsShape.empty())
- return lhs();
+ return shapes()[0];
if (!operands[0])
return nullptr;
@@ -368,7 +372,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto lhsShape = llvm::to_vector<6>(
operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
if (lhsShape.empty())
- return rhs();
+ return shapes()[1];
SmallVector<int64_t, 6> resultShape;
// If the shapes are not compatible, we can't fold it.
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 2bd4a1d34901..329e86848aa9 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,86 +305,6 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
// -----
-// CHECK-LABEL: func @broadcast_unknown_extents(
-// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) {
-func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[C1:.*]] = constant 1 : index
- // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
- // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
- // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
- // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
- // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
- // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
- // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
- // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
- // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
- // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
- // CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
- // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
- // CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
- // CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
- // CHECK: } else {
- // CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
- // CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
- // CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
- // CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
- // CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
- // CHECK: }
- // CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
- // CHECK: } : tensor<?xindex>
- // CHECK: return
- // CHECK: }
- %0 = shape.broadcast %a, %b
- : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
- return
-}
-
-// -----
-
-// CHECK-LABEL: func @broadcast_known_different_extents(
-// CHECK-SAME: %[[LHS:.*]]: tensor<2xindex>,
-// CHECK-SAME: %[[RHS:.*]]: tensor<3xindex>) {
-func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[C1:.*]] = constant 1 : index
- // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
- // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
- // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
- // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
- // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
- // CHECK: %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
- // CHECK: %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
- // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
- // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
- // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
- // CHECK: %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
- // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
- // CHECK: %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
- // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
- // CHECK: %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
- // CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
- // CHECK: } else {
- // CHECK: %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
- // CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
- // CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
- // CHECK: %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
- // CHECK: scf.yield %[[BROADCASTED_EXTENT]] : index
- // CHECK: }
- // CHECK: yield %[[OUTPUT_EXTENT:.*]] : index
- // CHECK: } : tensor<?xindex>
- // CHECK: return
- // CHECK: }
- %0 = shape.broadcast %a, %b
- : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
- return
-}
-
-// -----
-
func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
%0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
return %0 : i1
@@ -459,3 +379,62 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
// CHECK: return %[[RESULT]] : !shape.witness
// CHECK: }
+
+// -----
+
+func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
+ %b : tensor<3xindex>,
+ %c : tensor<2xindex>) {
+// CHECK-LABEL: func @broadcast_3_shapes_different_extents(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2xindex>) {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK: %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK: %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK: %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK: %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK: %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
+// CHECK: %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]] {
+// CHECK: ^bb0(%[[IDX:.*]]: index):
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK: scf.yield %[[C1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK: %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1]] : index
+// CHECK: %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
+// CHECK: }
+// CHECK: %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK: scf.yield %[[DIM0]] : index
+// CHECK: } else {
+// CHECK: %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK: %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1]] : index
+// CHECK: %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK: }
+// CHECK: %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK: scf.yield %[[DIM1]] : index
+// CHECK: } else {
+// CHECK: %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK: %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2:.*]], %[[C1]] : index
+// CHECK: %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK: }
+// CHECK: tensor.yield %[[DIM2]] : index
+// CHECK: } : tensor<?xindex>
+// CHECK: return
+// CHECK: }
+ %0 = shape.broadcast %a, %b, %c
+ : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+ return
+}
More information about the Mlir-commits
mailing list