[Mlir-commits] [mlir] 511484f - [mlir] Add lowering for IsBroadcastable to Std dialect.

Tres Popp llvmlistbot at llvm.org
Fri Oct 30 02:44:36 PDT 2020


Author: Tres Popp
Date: 2020-10-30T10:44:27+01:00
New Revision: 511484f27d923e16cbb88d3b00f800386a263cd5

URL: https://github.com/llvm/llvm-project/commit/511484f27d923e16cbb88d3b00f800386a263cd5
DIFF: https://github.com/llvm/llvm-project/commit/511484f27d923e16cbb88d3b00f800386a263cd5.diff

LOG: [mlir] Add lowering for IsBroadcastable to Std dialect.

Differential Revision: https://reviews.llvm.org/D90407

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0c300c4f9cc5..704b0cdb0324 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -207,6 +207,86 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite(
   return success();
 }
 
+namespace { // Anonymous namespace: the pattern type is local to this file.
+struct IsBroadcastableOpConverter
+    : public OpConversionPattern<IsBroadcastableOp> { // Conversion pattern for `shape.is_broadcastable`.
+  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern; // Inherit the base-class constructors.
+
+  LogicalResult
+  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override; // Defined out of line, directly below.
+};
+} // namespace
+
+LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
+    IsBroadcastableOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Replaces `op` with the i1 result of an scf.for reduction over the extents.
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not on shapes.
+  IsBroadcastableOp::Adaptor transformed(operands);
+  if (transformed.lhs().getType().isa<ShapeType>() ||
+      transformed.rhs().getType().isa<ShapeType>())
+    return failure(); // Bail out before any op is created.
+
+  auto loc = op.getLoc();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1); // Used both as the loop step and as the "extent == 1" comparand.
+
+  // Order the operands by rank: find the lesser and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero); // Size of dim 0 of the extent tensor is the rank.
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsRankULE =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank); // Unsigned <=; on a tie lhs is taken as the lesser operand.
+  Type indexTy = rewriter.getIndexType();
+  Value lesserRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); // 1-D index tensor with dynamic size.
+  Value rankErasedLhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); // Cast both operands to one common type so they can feed the selects below.
+  Value rankErasedRhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+  Value lesserRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
+  Value greaterRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank); // Count of leading extents of the greater-rank operand that have no counterpart.
+  Type i1Ty = rewriter.getI1Type();
+  Value init =
+      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true)); // Initial value of the loop-carried result.
+
+  // Determine if all overlapping extents are broadcastable; iv ranges over [rankDiff, greaterRank) in the greater-rank operand.
+  auto reduceResult = rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, ValueRange{init},
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+        Value greaterRankOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
+        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff); // Corresponding index into the lesser-rank operand.
+        Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+            loc, lesserRankOperand, ValueRange{ivShifted});
+        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+        Value extentsAreEqual =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                             lesserRankOperandExtent);
+        Value broadcastableExtents = b.create<AndOp>(
+            loc, iterArgs[0],
+            b.create<OrOp>(loc,
+                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                          lesserRankOperandExtentIsOne),
+                           extentsAreEqual)); // soFar && (either extent is 1 || extents are equal)
+        b.create<scf::YieldOp>(loc, broadcastableExtents); // Carry the running result to the next iteration.
+      });
+
+  rewriter.replaceOp(op, reduceResult.results().front()); // The loop's first (only iter_arg) result replaces the op.
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -522,6 +602,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BroadcastOpConverter,
       ConstShapeOpConverter,
       ConstSizeOpConversion,
+      IsBroadcastableOpConverter,
       GetExtentOpConverter,
       RankOpConverter,
       ReduceOpConverter,

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d827b6f4cc2e..56594d529e4d 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -382,3 +382,41 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
       : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 { // One statically sized and one dynamically sized extent tensor.
+  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex> // Op under test.
+  return %0 : i1
+}
+
+// CHECK-LABEL:   func @try_is_broadcastable(
+// CHECK-SAME:        %[[LHS:.*]]: tensor<3xindex>,
+// CHECK-SAME:        %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
+// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK:           %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK:           %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK:           %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK:           %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK:           %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK:           %[[TRUE:.*]] = constant true
+// CHECK:           %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK:             %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
+// CHECK:             %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK:             %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
+// CHECK:             %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK:             %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK:             %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK:             %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK:             %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK:             %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK:             scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK:           }
+// CHECK:           return %[[ALL_RESULT]] : i1
+// CHECK:         }


        


More information about the Mlir-commits mailing list