[Mlir-commits] [mlir] a975be0 - [mlir][shape] Make conversion passes more consistent.
Sean Silva
llvmlistbot at llvm.org
Mon Sep 28 15:01:36 PDT 2020
Author: Sean Silva
Date: 2020-09-28T14:55:42-07:00
New Revision: a975be0e00a12fdf09ffc9127825321c79813f33
URL: https://github.com/llvm/llvm-project/commit/a975be0e00a12fdf09ffc9127825321c79813f33
DIFF: https://github.com/llvm/llvm-project/commit/a975be0e00a12fdf09ffc9127825321c79813f33.diff
LOG: [mlir][shape] Make conversion passes more consistent.
- use select-ops to make the lowering simpler
- change style of FileCheck variable names to be consistent
- change some variable names in the code to be more explicit
Differential Revision: https://reviews.llvm.org/D88258
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index d4aad9d904ca..3ea6233700c3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -39,25 +39,17 @@ class ConvertCstrBroadcastableOp
// Find smaller and greater rank and extent tensor.
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
- Value lhsSmaller =
+ Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
- Type extentTensorTy = op.lhs().getType();
- auto ifOp = rewriter.create<scf::IfOp>(
- loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
- lhsSmaller,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(
- loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
- },
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(
- loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
- });
- Value lesserRank = ifOp.getResult(0);
- Value lesserRankOperand = ifOp.getResult(1);
- Value greaterRank = ifOp.getResult(2);
- Value greaterRankOperand = ifOp.getResult(3);
+ Value lesserRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+ Value greaterRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+ Value lesserRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+ Value greaterRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0a6953842a14..b1319a8cd386 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,27 +90,19 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
// Find smaller and greater rank and extent tensor.
- Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
- Value lhsSmaller =
+ Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+ Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+ Value lhsRankULE =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
- Type extentTensorTy = op.getType();
- auto ifOp = rewriter.create<IfOp>(
- loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
- lhsSmaller,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
- rhsRank, transformed.rhs()});
- },
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
- lhsRank, transformed.lhs()});
- });
- Value smallerRank = ifOp.getResult(0);
- Value smallerOperand = ifOp.getResult(1);
- Value greaterRank = ifOp.getResult(2);
- Value greaterOperand = ifOp.getResult(3);
+ Value lesserRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+ Value greaterRank =
+ rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+ Value lesserRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+ Value greaterRankOperand =
+ rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
// Allocate stack memory for the broadcasted extent tensor.
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
@@ -118,11 +110,11 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
// Copy extents from greater operand that are not challenged.
Value rankDiff =
- rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+ rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value extent = b.create<ExtractElementOp>(
- loc, greaterOperand, ValueRange{iv});
+ loc, greaterRankOperand, ValueRange{iv});
b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
b.create<scf::YieldOp>(loc);
});
@@ -132,16 +124,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterOperandExtent =
- b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+ b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
Value greaterOperandExtentIsOne =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
auto ifOp = b.create<IfOp>(
loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
[&](OpBuilder &b, Location loc) {
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
- Value smallerOperandExtent = b.create<ExtractElementOp>(
- loc, smallerOperand, ValueRange{ivShifted});
- b.create<scf::YieldOp>(loc, smallerOperandExtent);
+ Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+ loc, lesserRankOperand, ValueRange{ivShifted});
+ b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, greaterOperandExtent);
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 1f7b6d60dd4f..72349d5e44af 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -7,25 +7,24 @@
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[RET:.*]] = shape.const_witness true
-// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
-// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK: } else {
-// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK: }
-// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
-// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
-// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
-// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
-// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
-// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
-// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
-// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+// CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+// CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK: %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK: %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
+// CHECK: %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
+// CHECK: %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
+// CHECK: assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
// CHECK: }
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 01ba6abcc6c4..6207486db821 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -312,27 +312,26 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
- // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
- // CHECK: scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
- // CHECK: } else {
- // CHECK: scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
- // CHECK: }
- // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
- // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+ // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+ // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+ // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+ // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+ // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+ // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
+ // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
- // CHECK: %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
- // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
- // CHECK: %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
- // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+ // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+ // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+ // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
// CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
// CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
- // CHECK: %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
- // CHECK: scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+ // CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
+ // CHECK: scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
// CHECK: } else {
- // CHECK: scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+ // CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
// CHECK: }
// CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
// CHECK: }
@@ -341,4 +340,3 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
return
}
-
More information about the Mlir-commits
mailing list