[Mlir-commits] [mlir] 55d09df - [mlir] [VectorOps] Improve vector.create_mask lowering

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 23 14:34:56 PDT 2020


Author: aartbik
Date: 2020-06-23T14:33:41-07:00
New Revision: 55d09dfc7b147dbb74ae62173f4d5b078b19e328

URL: https://github.com/llvm/llvm-project/commit/55d09dfc7b147dbb74ae62173f4d5b078b19e328
DIFF: https://github.com/llvm/llvm-project/commit/55d09dfc7b147dbb74ae62173f4d5b078b19e328.diff

LOG: [mlir] [VectorOps] Improve vector.create_mask lowering

Use vector compares for the 1-D case. This approach scales much better
than generating one insertion operation per element, and exposes SIMD directly to the backend.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D82402

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a9207230278f..effae86e4597 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1328,6 +1328,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
     int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
 
     if (rank == 1) {
+      // Express constant 1-D case in explicit vector form:
+      //   [T,..,T,F,..,F].
       SmallVector<bool, 4> values(dstType.getDimSize(0));
       for (int64_t d = 0; d < trueDim; d++)
         values[d] = true;
@@ -1364,8 +1366,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
 ///   %1 = select %0, %l, %zeroes    |
 ///   %r = vector.insert %1, %pr [i] | d-times
 ///   %x = ....
-/// When rank == 1, the selection operator is not needed,
-/// and we can assign the true/false value right away.
+/// until a one-dimensional vector is reached.
 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
 public:
   using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
@@ -1375,30 +1376,41 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
     auto loc = op.getLoc();
     auto dstType = op.getResult().getType().cast<VectorType>();
     auto eltType = dstType.getElementType();
+    int64_t dim = dstType.getDimSize(0);
     int64_t rank = dstType.getRank();
     Value idx = op.getOperand(0);
 
-    Value trueVal;
-    Value falseVal;
-    if (rank > 1) {
-      VectorType lowType =
-          VectorType::get(dstType.getShape().drop_front(), eltType);
-      trueVal = rewriter.create<vector::CreateMaskOp>(
-          loc, lowType, op.getOperands().drop_front());
-      falseVal = rewriter.create<ConstantOp>(loc, lowType,
-                                             rewriter.getZeroAttr(lowType));
+    if (rank == 1) {
+      // Express dynamic 1-D case in explicit vector form:
+      //   mask = [0,1,..,n-1] < [a,a,..,a]
+      SmallVector<int64_t, 4> values(dim);
+      for (int64_t d = 0; d < dim; d++)
+        values[d] = d;
+      Value indices =
+          rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
+      Value bound =
+          rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
+      Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+      rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
+                                          bounds);
+      return success();
     }
 
+    VectorType lowType =
+        VectorType::get(dstType.getShape().drop_front(), eltType);
+    Value trueVal = rewriter.create<vector::CreateMaskOp>(
+        loc, lowType, op.getOperands().drop_front());
+    Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
+                                                 rewriter.getZeroAttr(lowType));
     Value result = rewriter.create<ConstantOp>(loc, dstType,
                                                rewriter.getZeroAttr(dstType));
-    for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) {
+    for (int64_t d = 0; d < dim; d++) {
       Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
       Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
-      if (rank > 1)
-        val = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
+      Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
       auto pos = rewriter.getI64ArrayAttr(d);
       result =
-          rewriter.create<vector::InsertOp>(loc, dstType, val, result, pos);
+          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
     }
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index d4c68ab2cebd..a07675515a9b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -710,18 +710,12 @@ func @genbool_3d() -> vector<2x3x4xi1> {
 }
 
 // CHECK-LABEL: func @genbool_var_1d
-// CHECK-SAME: %[[A:.*0]]: index
-// CHECK-DAG:  %[[VF:.*]] = constant dense<false> : vector<3xi1>
-// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
-// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
-// CHECK-DAG:  %[[C2:.*]] = constant 2 : index
-// CHECK:      %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
-// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
-// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
-// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
-// CHECK:      %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
-// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
-// CHECK:      return %[[T5]] : vector<3xi1>
+// CHECK-SAME: %[[A:.*]]: index
+// CHECK:      %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
+// CHECK:      %[[T0:.*]] = index_cast %[[A]] : index to i64
+// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
+// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
+// CHECK:      return %[[T2]] : vector<3xi1>
 
 func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
   %0 = vector.create_mask %arg0 : vector<3xi1>
@@ -731,24 +725,21 @@ func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
 // CHECK-LABEL: func @genbool_var_2d
 // CHECK-SAME: %[[A:.*0]]: index
 // CHECK-SAME: %[[B:.*1]]: index
-// CHECK-DAG:  %[[Z1:.*]] = constant dense<false> : vector<3xi1>
-// CHECK-DAG:  %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
-// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
-// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
-// CHECK-DAG:  %[[C2:.*]] = constant 2 : index
-// CHECK:      %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
-// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
-// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
-// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
-// CHECK:      %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
-// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
-// CHECK:      %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
-// CHECK:      %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
-// CHECK:      %[[T8:.*]] = vector.insert %7, %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK:      %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
-// CHECK:      %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
-// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
-// CHECK:      return %[[T11]] : vector<2x3xi1>
+// CHECK:      %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
+// CHECK:      %[[CF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK:      %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
+// CHECK:      %[[c0:.*]] = constant 0 : index
+// CHECK:      %[[c1:.*]] = constant 1 : index
+// CHECK:      %[[T0:.*]] = index_cast %[[B]] : index to i64
+// CHECK:      %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
+// CHECK:      %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
+// CHECK:      %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK:      %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
+// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK:      %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
+// CHECK:      %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      return %[[T8]] : vector<2x3xi1>
 
 func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
   %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>


        


More information about the Mlir-commits mailing list