[Mlir-commits] [mlir] b1c688d - [mlir] [VectorOps] Implement vector.create_mask lowering to LLVM IR

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 15 11:02:44 PDT 2020


Author: aartbik
Date: 2020-05-15T11:02:30-07:00
New Revision: b1c688dbae696e30ae5fb22677bfbfa255117f9f

URL: https://github.com/llvm/llvm-project/commit/b1c688dbae696e30ae5fb22677bfbfa255117f9f
DIFF: https://github.com/llvm/llvm-project/commit/b1c688dbae696e30ae5fb22677bfbfa255117f9f.diff

LOG: [mlir] [VectorOps] Implement vector.create_mask lowering to LLVM IR

Summary:
A first, compact implementation of the lowering to LLVM IR. It is a bit more
challenging than the constant mask because of the dynamic indices, of course.
I would like to hear if there are more efficient ways of doing this in LLVM,
but for now this at least gives us a functional reference implementation.

Reviewers: nicolasvasilache, ftynse, bkramer, reidtatge, andydavis1, mehdi_amini

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79954

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-contract-transforms.mlir
    mlir/test/Target/vector-to-llvm-ir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 51cb61b0eaca..eb25bf3abf85 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1146,6 +1146,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;
+    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
     populateVectorSlicesLoweringPatterns(patterns, &getContext());
     populateVectorContractLoweringPatterns(patterns, &getContext());
     applyPatternsAndFoldGreedily(getOperation(), patterns);

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 4f8e2374a251..851b54beb452 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1236,6 +1236,45 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
   }
 };
 
+/// Progressive lowering of CreateMaskOp.
+/// Peels off the leading dimension. For each position d in [0, dim0) it emits
+/// the scalar test `d < operand0` (signed compare on index values):
+///   - rank 1: the resulting i1 is inserted directly into the result vector;
+///   - rank > 1: the i1 selects, for row d, between a lower-rank
+///     vector.create_mask built from the remaining operands and an
+///     all-false constant of that trailing shape.
+/// The newly created lower-rank vector.create_mask is itself rewritten when
+/// this pattern is applied again (e.g. during greedy pattern application).
+class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto dstType = op.getResult().getType().cast<VectorType>();
+    auto eltType = dstType.getElementType();
+    int64_t rank = dstType.getRank();
+    // Dynamic bound for the leading (outermost) dimension.
+    Value idx = op.getOperand(0);
+
+    // Only used when rank > 1: every row of the result is either the full
+    // trailing-dimension mask (created once, reused for all rows) or the
+    // all-false vector of the same trailing shape.
+    Value trueVal;
+    Value falseVal;
+    if (rank > 1) {
+      VectorType lowType =
+          VectorType::get(dstType.getShape().drop_front(), eltType);
+      trueVal = rewriter.create<vector::CreateMaskOp>(
+          loc, lowType, op.getOperands().drop_front());
+      falseVal = rewriter.create<ConstantOp>(loc, lowType,
+                                             rewriter.getZeroAttr(lowType));
+    }
+
+    // Start from an all-false result and overwrite one entry (rank 1) or one
+    // row (rank > 1) per leading position.
+    Value result = rewriter.create<ConstantOp>(loc, dstType,
+                                               rewriter.getZeroAttr(dstType));
+    for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) {
+      Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
+      // Position d is enabled iff d < idx. The compare is signed, so a
+      // non-positive idx leaves every position disabled.
+      Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
+      if (rank > 1)
+        val = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
+      auto pos = rewriter.getI64ArrayAttr(d);
+      result =
+          rewriter.create<vector::InsertOp>(loc, dstType, val, result, pos);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1659,6 +1698,6 @@ void mlir::vector::populateVectorContractLoweringPatterns(
   patterns.insert<ShapeCastOp2DDownCastRewritePattern,
                   ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
                   TransposeOpLowering, OuterProductOpLowering,
-                  ConstantMaskOpLowering>(context);
+                  ConstantMaskOpLowering, CreateMaskOpLowering>(context);
   patterns.insert<ContractionOpLowering>(parameters, context);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index a4d4fe58e8d5..72270dab1153 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -605,3 +605,49 @@ func @genbool_3d() -> vector<2x3x4xi1> {
   %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
   return %v: vector<2x3x4xi1>
 }
+
+// A 1-D vector.create_mask with a dynamic bound is expected to become one
+// scalar signed "slt" compare per position, each inserted into a vector that
+// starts out all-false.
+// CHECK-LABEL: func @genbool_var_1d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-DAG:  %[[VF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:  %[[C2:.*]] = constant 2 : index
+// CHECK:      %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
+// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK:      %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
+// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK:      return %[[T5]] : vector<3xi1>
+
+func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
+  %0 = vector.create_mask %arg0 : vector<3xi1>
+  return %0 : vector<3xi1>
+}
+
+// A 2-D vector.create_mask: the inner 1-D mask over the trailing dimension
+// (bound %arg1) is built once, then each leading row selects between it and
+// an all-false row according to a signed compare against %arg0.
+// CHECK-LABEL: func @genbool_var_2d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-SAME: %[[B:.*1]]: index
+// CHECK-DAG:  %[[Z1:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG:  %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
+// CHECK-DAG:  %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:  %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:  %[[C2:.*]] = constant 2 : index
+// CHECK:      %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
+// CHECK:      %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
+// CHECK:      %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
+// CHECK:      %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK:      %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
+// CHECK:      %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK:      %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK:      %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK:      %[[T8:.*]] = vector.insert %[[T7]], %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK:      %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK:      %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK:      return %[[T11]] : vector<2x3xi1>
+
+func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
+  %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
+  return %0 : vector<2x3xi1>
+}

diff  --git a/mlir/test/Target/vector-to-llvm-ir.mlir b/mlir/test/Target/vector-to-llvm-ir.mlir
index 354f70f3f4b9..4ede6ca2a5df 100644
--- a/mlir/test/Target/vector-to-llvm-ir.mlir
+++ b/mlir/test/Target/vector-to-llvm-ir.mlir
@@ -21,3 +21,11 @@ func @genbool_3d() -> vector<2x3x4xi1> {
 // CHECK-LABEL: @genbool_3d()
 // CHECK-NEXT: ret [2 x [3 x <4 x i1>]] {{\[+}}3 x <4 x i1>] [<4 x i1> <i1 true, i1 true, i1 true, i1 false>, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer], [3 x <4 x i1>] zeroinitializer]
 // note: awkward syntax to match [[
+
+// vector.create_mask whose bound is the constant 0: no position satisfies
+// `d < 0`, so the emitted LLVM IR is expected to return an all-false
+// (zeroinitializer) mask.
+func @genbool_1d_var_but_constant() -> vector<8xi1> {
+  %i = constant 0 : index
+  %v = vector.create_mask %i : vector<8xi1>
+  return %v : vector<8xi1>
+}
+// CHECK-LABEL: @genbool_1d_var_but_constant()
+// CHECK-NEXT: ret <8 x i1> zeroinitializer


        


More information about the Mlir-commits mailing list