[Mlir-commits] [mlir] f30f347 - [mlir][shape] Generalize broadcast to a variadic number of shapes

Tres Popp llvmlistbot at llvm.org
Tue Feb 9 23:31:50 PST 2021


Author: Tres Popp
Date: 2021-02-10T08:31:28+01:00
New Revision: f30f347da1f8b9da231368f37538a8de49768d49

URL: https://github.com/llvm/llvm-project/commit/f30f347da1f8b9da231368f37538a8de49768d49
DIFF: https://github.com/llvm/llvm-project/commit/f30f347da1f8b9da231368f37538a8de49768d49.diff

LOG: [mlir][shape] Generalize broadcast to a variadic number of shapes

Previously, broadcast was a binary op; it now accepts a variadic number of
input shapes. The change is structured so that, for now, it is an NFC for
all broadcast operations that were previously legal.

Differential Revision: https://reviews.llvm.org/D95777
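
For illustration, based on the updated assembly format and the tests in this
change, the op now takes a comma-separated variadic operand list (the binary
form below matches the previously legal syntax; %a, %b, %c stand for any
extent-tensor values):

    // Binary broadcast, unchanged in behavior.
    %0 = shape.broadcast %a, %b
        : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>

    // Generalized form with three shapes.
    %1 = shape.broadcast %a, %b, %c
        : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>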

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ba89a9455781..271a4f87eec9 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -50,12 +50,13 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
 }
 
 def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
-  let summary = "Returns the broadcasted output shape of two inputs";
+  let summary = "Returns the broadcasted output shape of two or more inputs";
   let description = [{
-    Returns the broadcasted shape for two input shapes or extent tensors. Both
-    operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
-    type `shape.shape` and, if both operands are tensors, may be of type
-    `tensor<?xindex>`.
+    Returns the broadcasted shape for input shapes or extent tensors. The rest
+    of this description is simplified for the 2 input case but can be extended
+    to more inputs. Both operands can be of type `shape.shape` or
+    `tensor<?xindex>`. The result is of type `shape.shape` and, if both
+    operands are tensors, may be of type `tensor<?xindex>`.
 
    If the two operand shapes are of different rank the smaller one is padded
     with 1's from the left. The resulting broadcasted shape is then defined as
@@ -72,19 +73,26 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
     attribute can be used to describe the error case.
   }];
 
-  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
-                       Shape_ShapeOrExtentTensorType:$rhs,
+  let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes,
                        OptionalAttr<StrAttr>:$error);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
   let assemblyFormat = [{
-    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+    $shapes attr-dict `:` type($shapes) `->` type($result)
   }];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
-  let hasFolder = 1;
+  let builders = [OpBuilderDAG<(ins "::mlir::Type":$result,
+                                "::mlir::Value":$lhs, "::mlir::Value":$rhs,
+                                "/*optional*/ ::mlir::StringAttr":$error), [{
+      build($_builder, $_state, result, ::llvm::makeArrayRef({lhs, rhs}), error);
+    }]>
+  ];
 
-  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
+  let hasFolder = 1;
+  let verifier = [{
+    return success(succeeded(::verifyShapeOrExtentTensorOp(*this)) &&
+                   getNumOperands() >= 2);
+  }];
 }
 
 def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0eeea250f19f..3c83b4371df3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::shape;
@@ -73,6 +75,48 @@ struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
   matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+// Get the resulting extent in a given dimension. This is computed with any
+// number of extent tensors and shifted offsets into them.
+Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
+                        ValueRange rankDiffs, Value outputDimension) {
+  Value one = lb.create<ConstantIndexOp>(1);
+  Value broadcastedDim = one;
+  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
+    Value shape = std::get<0>(tup);
+    Value rankDiff = std::get<1>(tup);
+    Value outOfBounds =
+        lb.create<CmpIOp>(CmpIPredicate::ult, outputDimension, rankDiff);
+    Type indexTy = lb.getIndexType();
+    broadcastedDim =
+        lb.create<IfOp>(
+              TypeRange{indexTy}, outOfBounds,
+              [&](OpBuilder &b, Location loc) {
+                b.create<scf::YieldOp>(loc, broadcastedDim);
+              },
+              [&](OpBuilder &b, Location loc) {
+                // The broadcasting logic is:
+                // - if one extent (here we arbitrarily choose the
+                // extent from the greater-rank operand) is equal to 1,
+                // then take the extent from the other operand
+                // - otherwise, take the extent as-is.
+                // Note that this logic remains correct in the presence
+                // of dimensions of zero extent.
+                Value lesserRankOperandDimension =
+                    b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
+                    loc, shape, ValueRange{lesserRankOperandDimension});
+
+                Value dimIsOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                  lesserRankOperandExtent, one);
+                Value dim = b.create<SelectOp>(loc, dimIsOne, broadcastedDim,
+                                               lesserRankOperandExtent);
+                b.create<scf::YieldOp>(loc, dim);
+              })
+            .getResult(0);
+  }
+  return broadcastedDim;
+}
 } // namespace
 
 LogicalResult BroadcastOpConverter::matchAndRewrite(
@@ -83,76 +127,44 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   if (op.getType().isa<ShapeType>())
     return failure();
 
-  assert(!op.lhs().getType().isa<ShapeType>() &&
-         !op.rhs().getType().isa<ShapeType>());
   auto loc = op.getLoc();
+  ImplicitLocOpBuilder lb(loc, rewriter);
   BroadcastOp::Adaptor transformed(operands);
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
 
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+
+  // Find the maximum rank
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
 
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
-      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
-      [&](OpBuilder &b, Location loc, ValueRange args) {
-        Value outputDimension = args[0];
-        Value isUnchallengedDimension = b.create<CmpIOp>(
-            loc, CmpIPredicate::ult, outputDimension, rankDiff);
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, outputDimension);
-        // The initial dimensions of the greater-rank operand are unchallenged,
-        // so we can take them as-is. Otherwise, we need to do a comparison.
-        // We need an actual branch here (instead of a select) because the
-        // lesser-rank operand might be rank 0, so any tensor.extract would be
-        // invalid.
-        auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, isUnchallengedDimension,
-            [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
-            },
-            [&](OpBuilder &b, Location loc) {
-              // The broadcasting logic is:
-              // - if one extent (here we arbitrarily choose the extent from
-              // the greater-rank operand) is equal to 1, then take the extent
-              // from the other operand
-              // - otherwise, take the extent as-is.
-              // Note that this logic remains correct in the presence of
-              // dimensions of zero extent.
-              Value lesserRankOperandDimension =
-                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
-              Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-                  loc, lesserRankOperand,
-                  ValueRange{lesserRankOperandDimension});
-              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-              Value broadcastedExtent = b.create<SelectOp>(
-                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
-                  greaterRankOperandExtent);
-              b.create<scf::YieldOp>(loc, broadcastedExtent);
-            });
-        b.create<tensor::YieldOp>(loc, ifOp.getResult(0));
-      });
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
+
+  rewriter.replaceOp(
+      op, lb.create<tensor::GenerateOp>(
+                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value broadcastedDim = getBroadcastedDim(
+                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
+                      rankDiffs, args[0]);
+
+                  b.create<tensor::YieldOp>(loc, broadcastedDim);
+                })
+              ->getResults());
   return success();
 }
 

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65ebc54aeeb3..9657f9566ea6 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -357,10 +357,14 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[1])
     return nullptr;
 
+  // TODO: Support folding with more than 2 input shapes
+  if (operands.size() > 2 && !operands[2].isa<StringAttr>())
+    return nullptr;
+
   auto rhsShape = llvm::to_vector<6>(
       operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (rhsShape.empty())
-    return lhs();
+    return shapes()[0];
 
   if (!operands[0])
     return nullptr;
@@ -368,7 +372,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
   if (lhsShape.empty())
-    return rhs();
+    return shapes()[1];
 
   SmallVector<int64_t, 6> resultShape;
   // If the shapes are not compatible, we can't fold it.

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 2bd4a1d34901..329e86848aa9 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,86 +305,6 @@ func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL:   func @broadcast_unknown_extents(
-// CHECK-SAME:                                    %[[LHS:.*]]: tensor<?xindex>,
-// CHECK-SAME:                                    %[[RHS:.*]]: tensor<?xindex>) {
-func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK:           %[[C0:.*]] = constant 0 : index
-  // CHECK:           %[[C1:.*]] = constant 1 : index
-  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK:           %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:             } else {
-  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:             }
-  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK:           } : tensor<?xindex>
-  // CHECK:           return
-  // CHECK:         }
-  %0 = shape.broadcast %a, %b
-      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
-// CHECK-LABEL:   func @broadcast_known_different_extents(
-// CHECK-SAME:                                            %[[LHS:.*]]: tensor<2xindex>,
-// CHECK-SAME:                                            %[[RHS:.*]]: tensor<3xindex>) {
-func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
-  // CHECK:           %[[C0:.*]] = constant 0 : index
-  // CHECK:           %[[C1:.*]] = constant 1 : index
-  // CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
-  // CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
-  // CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi ule, %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK:           %[[ERASED_LHS:.*]] = tensor.cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-  // CHECK:           %[[ERASED_RHS:.*]] = tensor.cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK:           %[[RESULT:.*]] = tensor.generate %[[GREATER_RANK]] {
-  // CHECK:           ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
-  // CHECK:             %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi ult, %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
-  // CHECK:             %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
-  // CHECK:               scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:             } else {
-  // CHECK:               %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
-  // CHECK:               %[[LESSER_RANK_OPERAND_EXTENT:.*]] = tensor.extract %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
-  // CHECK:               %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi eq, %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:               %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:               scf.yield %[[BROADCASTED_EXTENT]] : index
-  // CHECK:             }
-  // CHECK:             yield %[[OUTPUT_EXTENT:.*]] : index
-  // CHECK:           } : tensor<?xindex>
-  // CHECK:           return
-  // CHECK:         }
-  %0 = shape.broadcast %a, %b
-      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
-  return
-}
-
-// -----
-
 func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
   %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
   return %0 : i1
@@ -459,3 +379,62 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
 // CHECK:           %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
 // CHECK:           return %[[RESULT]] : !shape.witness
 // CHECK:         }
+
+// -----
+
+func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
+                                           %b : tensor<3xindex>,
+                                           %c : tensor<2xindex>) {
+// CHECK-LABEL:   func @broadcast_3_shapes_different_extents(
+// CHECK-SAME:          %[[ARG0:.*]]: tensor<2xindex>,
+// CHECK-SAME:          %[[ARG1:.*]]: tensor<3xindex>,
+// CHECK-SAME:          %[[ARG2:.*]]: tensor<2xindex>) {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[RANK0:.*]] = dim %[[ARG0]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[RANK1:.*]] = dim %[[ARG1]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RANK2:.*]] = dim %[[ARG2]], %[[C0]] : tensor<2xindex>
+// CHECK:           %[[CMP0:.*]] = cmpi ugt, %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[LARGER_DIM:.*]] = select %[[CMP0]], %[[RANK1]], %[[RANK0]] : index
+// CHECK:           %[[CMP1:.*]] = cmpi ugt, %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[MAX_RANK:.*]] = select %[[CMP1]], %[[RANK2]], %[[LARGER_DIM]] : index
+// CHECK:           %[[DIM_DIFF0:.*]] = subi %[[MAX_RANK]], %[[RANK0]] : index
+// CHECK:           %[[DIM_DIFF1:.*]] = subi %[[MAX_RANK]], %[[RANK1]] : index
+// CHECK:           %[[DIM_DIFF2:.*]] = subi %[[MAX_RANK]], %[[RANK2]] : index
+// CHECK:           %[[RESULT:.*]] = tensor.generate %[[MAX_RANK]]  {
+// CHECK:           ^bb0(%[[IDX:.*]]: index):
+// CHECK:             %[[C1:.*]] = constant 1 : index
+// CHECK:             %[[OUTBOUNDS0:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:             %[[DIM0:.*]] = scf.if %[[OUTBOUNDS0]] -> (index) {
+// CHECK:               scf.yield %[[C1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX0:.*]] = subi %[[IDX]], %[[DIM_DIFF0]] : index
+// CHECK:               %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[IDX0]]] : tensor<2xindex>
+// CHECK:               %[[DIM0_IS_1:.*]] = cmpi eq, %[[EXTRACTED_0:.*]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM0:.*]] = select %[[DIM0_IS_1]], %[[C1]], %[[EXTRACTED_0]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_28:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:             %[[DIM1:.*]] = scf.if %[[VAL_28]] -> (index) {
+// CHECK:               scf.yield %[[DIM0]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX1:.*]] = subi %[[IDX]], %[[DIM_DIFF1]] : index
+// CHECK:               %[[EXTRACTED_1:.*]] = tensor.extract %[[ARG1]]{{\[}}%[[IDX1]]] : tensor<3xindex>
+// CHECK:               %[[DIM1_IS_1:.*]] = cmpi eq, %[[EXTRACTED_1:.*]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM1:.*]] = select %[[DIM1_IS_1]], %[[DIM0]], %[[EXTRACTED_1]] : index
+// CHECK:             }
+// CHECK:             %[[VAL_36:.*]] = cmpi ult, %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:             %[[DIM2:.*]] = scf.if %[[VAL_36]] -> (index) {
+// CHECK:               scf.yield %[[DIM1]] : index
+// CHECK:             } else {
+// CHECK:               %[[IDX2:.*]] = subi %[[IDX]], %[[DIM_DIFF2]] : index
+// CHECK:               %[[EXTRACTED_2:.*]] = tensor.extract %[[ARG2]]{{\[}}%[[IDX2]]] : tensor<2xindex>
+// CHECK:               %[[DIM2_IS_1:.*]] = cmpi eq, %[[EXTRACTED_2:.*]], %[[C1]] : index
+// CHECK:               %[[MAX_DIM2:.*]] = select %[[DIM2_IS_1]], %[[DIM1]], %[[EXTRACTED_2]] : index
+// CHECK:             }
+// CHECK:             tensor.yield %[[DIM2]] : index
+// CHECK:           } : tensor<?xindex>
+// CHECK:           return
+// CHECK:         }
+  %0 = shape.broadcast %a, %b, %c
+      : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+  return
+}
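
As a sketch of the tightened verifier (which now also requires at least two
operands), a single-shape broadcast such as the following is assumed to be
rejected:

    // Assumed invalid: the new verifier requires getNumOperands() >= 2.
    %0 = shape.broadcast %a : tensor<?xindex> -> tensor<?xindex>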

