[Mlir-commits] [mlir] a975be0 - [mlir][shape] Make conversion passes more consistent.

Sean Silva llvmlistbot at llvm.org
Mon Sep 28 15:01:36 PDT 2020


Author: Sean Silva
Date: 2020-09-28T14:55:42-07:00
New Revision: a975be0e00a12fdf09ffc9127825321c79813f33

URL: https://github.com/llvm/llvm-project/commit/a975be0e00a12fdf09ffc9127825321c79813f33
DIFF: https://github.com/llvm/llvm-project/commit/a975be0e00a12fdf09ffc9127825321c79813f33.diff

LOG: [mlir][shape] Make conversion passes more consistent.

- use select ops to make the lowering simpler
- change the style of FileCheck variable names to be consistent
- change some variable names in the code to be more explicit

Differential Revision: https://reviews.llvm.org/D88258
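
The essence of the new lowering is that a single "cmpi ule" feeds four
select ops, replacing the four-result scf.if used before. A minimal
standalone sketch of that pattern follows; it is not code from this
patch, the helper orderByRank and struct RankOrdering are hypothetical
names, and it assumes the 2020-era standard dialect API where DimOp,
CmpIOp, and SelectOp live in mlir/Dialect/StandardOps/IR/Ops.h:

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Results of ordering two extent tensors by rank.
struct RankOrdering {
  Value lesserRank, greaterRank, lesserRankOperand, greaterRankOperand;
};

// Order lhs/rhs by rank using straight-line selects instead of an
// scf.if that yields four results.
static RankOrdering orderByRank(OpBuilder &b, Location loc, Value lhs,
                                Value rhs) {
  Value zero = b.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = b.create<DimOp>(loc, lhs, zero);
  Value rhsRank = b.create<DimOp>(loc, rhs, zero);
  // A single comparison drives all four selects.
  Value lhsRankULE =
      b.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
  RankOrdering r;
  r.lesserRank = b.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
  r.greaterRank = b.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
  r.lesserRankOperand = b.create<SelectOp>(loc, lhsRankULE, lhs, rhs);
  r.greaterRankOperand = b.create<SelectOp>(loc, lhsRankULE, rhs, lhs);
  return r;
}

Because select has no regions, the lowering stays a straight-line
sequence of ops rather than branching control flow, which is the
simplification the first bullet above refers to.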

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index d4aad9d904ca..3ea6233700c3 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -39,25 +39,17 @@ class ConvertCstrBroadcastableOp
     // Find smaller and greater rank and extent tensor.
     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-    Value lhsSmaller =
+    Value lhsRankULE =
         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
     Type indexTy = rewriter.getIndexType();
-    Type extentTensorTy = op.lhs().getType();
-    auto ifOp = rewriter.create<scf::IfOp>(
-        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-        lhsSmaller,
-        [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(
-              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
-        },
-        [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(
-              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
-        });
-    Value lesserRank = ifOp.getResult(0);
-    Value lesserRankOperand = ifOp.getResult(1);
-    Value greaterRank = ifOp.getResult(2);
-    Value greaterRankOperand = ifOp.getResult(3);
+    Value lesserRank =
+        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+    Value greaterRank =
+        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+    Value lesserRankOperand =
+        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+    Value greaterRankOperand =
+        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
 
     Value rankDiff =
         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);

diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0a6953842a14..b1319a8cd386 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,27 +90,19 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);
 
   // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsSmaller =
+  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+  Value lhsRankULE =
       rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
   Type indexTy = rewriter.getIndexType();
-  Type extentTensorTy = op.getType();
-  auto ifOp = rewriter.create<IfOp>(
-      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-      lhsSmaller,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
-                                               rhsRank, transformed.rhs()});
-      },
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
-                                               lhsRank, transformed.lhs()});
-      });
-  Value smallerRank = ifOp.getResult(0);
-  Value smallerOperand = ifOp.getResult(1);
-  Value greaterRank = ifOp.getResult(2);
-  Value greaterOperand = ifOp.getResult(3);
+  Value lesserRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  Value lesserRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+  Value greaterRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
 
   // Allocate stack memory for the broadcasted extent tensor.
   Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
@@ -118,11 +110,11 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
 
   // Copy extents from greater operand that are not challenged.
   Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
   rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
                          [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
                            Value extent = b.create<ExtractElementOp>(
-                               loc, greaterOperand, ValueRange{iv});
+                               loc, greaterRankOperand, ValueRange{iv});
                            b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
                            b.create<scf::YieldOp>(loc);
                          });
@@ -132,16 +124,16 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
       loc, rankDiff, greaterRank, one, llvm::None,
       [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
         Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
         Value greaterOperandExtentIsOne =
             b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
         auto ifOp = b.create<IfOp>(
             loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
             [&](OpBuilder &b, Location loc) {
               Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value smallerOperandExtent = b.create<ExtractElementOp>(
-                  loc, smallerOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+                  loc, lesserRankOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
             },
             [&](OpBuilder &b, Location loc) {
               b.create<scf::YieldOp>(loc, greaterOperandExtent);

diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
index 1f7b6d60dd4f..72349d5e44af 100644
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -7,25 +7,24 @@
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
 // CHECK:           %[[RET:.*]] = shape.const_witness true
-// CHECK:           %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
-// CHECK:           %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-// CHECK:             scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK:           } else {
-// CHECK:             scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK:           }
-// CHECK:           %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
-// CHECK:           scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
-// CHECK:             %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
-// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
-// CHECK:             %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK:             %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK:             %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK:             %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
-// CHECK:             %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
-// CHECK:             %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
-// CHECK:             assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+// CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+// CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+// CHECK:           scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK:             %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
+// CHECK:             %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
+// CHECK:             %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
+// CHECK:             assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
 // CHECK:           }
 // CHECK:           return %[[RET]] : !shape.witness
 // CHECK:         }

diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 01ba6abcc6c4..6207486db821 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -312,27 +312,26 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
   // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
-  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: } else {
-  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: }
-  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
   // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
   // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
   // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
-  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
   // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
   // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
   // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
   // CHECK:   }
   // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
   // CHECK: }
@@ -341,4 +340,3 @@ func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return
 }
-
