[Mlir-commits] [mlir] b1c688d - [mlir] [VectorOps] Implement vector.create_mask lowering to LLVM IR
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 15 11:02:44 PDT 2020
Author: aartbik
Date: 2020-05-15T11:02:30-07:00
New Revision: b1c688dbae696e30ae5fb22677bfbfa255117f9f
URL: https://github.com/llvm/llvm-project/commit/b1c688dbae696e30ae5fb22677bfbfa255117f9f
DIFF: https://github.com/llvm/llvm-project/commit/b1c688dbae696e30ae5fb22677bfbfa255117f9f.diff
LOG: [mlir] [VectorOps] Implement vector.create_mask lowering to LLVM IR
Summary:
A first, compact implementation of the lowering to LLVM IR. This is a bit
more challenging than the constant mask because of the dynamic indices, of course.
I would like to hear whether there are more efficient ways of doing this in
LLVM, but for now this at least gives us a functional reference implementation.
Reviewers: nicolasvasilache, ftynse, bkramer, reidtatge, andydavis1, mehdi_amini
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79954
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/Target/vector-to-llvm-ir.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 51cb61b0eaca..eb25bf3abf85 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1146,6 +1146,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
// all contraction operations. Also applies folding and DCE.
{
OwningRewritePatternList patterns;
+ populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
populateVectorSlicesLoweringPatterns(patterns, &getContext());
populateVectorContractLoweringPatterns(patterns, &getContext());
applyPatternsAndFoldGreedily(getOperation(), patterns);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 4f8e2374a251..851b54beb452 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1236,6 +1236,45 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
}
};
+/// Progressive lowering of CreateMaskOp.
+/// One:
+///   %x = vector.create_mask %a, ... : vector<dx...>
+/// is replaced by:
+///   %l = vector.create_mask ... : vector<...>  ; one rank lower
+///   %0 = cmpi "slt", %ci, %a       |
+///   %1 = select %0, %l, %zeroes    |
+///   %r = vector.insert %1, %pr [i] | d-times
+///   %x = %r
+/// For a 1-D mask there is no lower-rank mask: the i1 result of each
+/// comparison is inserted directly, without a select.
+class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
+public:
+  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType maskType = op.getResult().getType().cast<VectorType>();
+    Type elemType = maskType.getElementType();
+    bool hasTrailingDims = maskType.getRank() > 1;
+    Value leadingBound = op.getOperand(0);
+
+    // With trailing dimensions, every slot of the leading dimension holds
+    // either the mask over the trailing dimensions (a lower-rank
+    // create_mask, itself matched by this pattern) or an all-zero vector.
+    Value trailingMask;
+    Value trailingZero;
+    if (hasTrailingDims) {
+      VectorType trailingType =
+          VectorType::get(maskType.getShape().drop_front(), elemType);
+      trailingMask = rewriter.create<vector::CreateMaskOp>(
+          loc, trailingType, op.getOperands().drop_front());
+      trailingZero = rewriter.create<ConstantOp>(
+          loc, trailingType, rewriter.getZeroAttr(trailingType));
+    }
+
+    // Start from an all-zero mask and fill in the leading dimension one
+    // position at a time.
+    Value mask = rewriter.create<ConstantOp>(loc, maskType,
+                                             rewriter.getZeroAttr(maskType));
+    const int64_t leadingSize = maskType.getDimSize(0);
+    for (int64_t pos = 0; pos < leadingSize; ++pos) {
+      Value posCst =
+          rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(pos));
+      // Signed compare: position `pos` is set iff pos < leadingBound, so a
+      // non-positive bound leaves every position unset.
+      Value slot = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, posCst,
+                                           leadingBound);
+      if (hasTrailingDims)
+        slot = rewriter.create<SelectOp>(loc, slot, trailingMask,
+                                         trailingZero);
+      mask = rewriter.create<vector::InsertOp>(loc, maskType, slot, mask,
+                                               rewriter.getI64ArrayAttr(pos));
+    }
+    rewriter.replaceOp(op, mask);
+    return success();
+  }
+};
+
/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
@@ -1659,6 +1698,6 @@ void mlir::vector::populateVectorContractLoweringPatterns(
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
TransposeOpLowering, OuterProductOpLowering,
- ConstantMaskOpLowering>(context);
+ ConstantMaskOpLowering, CreateMaskOpLowering>(context);
patterns.insert<ContractionOpLowering>(parameters, context);
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index a4d4fe58e8d5..72270dab1153 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -605,3 +605,49 @@ func @genbool_3d() -> vector<2x3x4xi1> {
%v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
return %v: vector<2x3x4xi1>
}
+
+// A 1-D vector.create_mask with a dynamic bound lowers to one signed
+// comparison of each position against the bound; each i1 result is
+// inserted, position by position, into an all-false vector.
+// CHECK-LABEL: func @genbool_var_1d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-DAG: %[[VF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK: return %[[T5]] : vector<3xi1>
+
+func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
+ %0 = vector.create_mask %arg0 : vector<3xi1>
+ return %0 : vector<3xi1>
+}
+
+// A 2-D vector.create_mask first lowers the trailing dimension into a
+// vector<3xi1> mask; each leading row then selects between that mask and an
+// all-false row, based on a signed comparison against the leading bound.
+// CHECK-LABEL: func @genbool_var_2d
+// CHECK-SAME: %[[A:.*0]]: index
+// CHECK-SAME: %[[B:.*1]]: index
+// CHECK-DAG: %[[Z1:.*]] = constant dense<false> : vector<3xi1>
+// CHECK-DAG: %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
+// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
+// CHECK: %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
+// CHECK: %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK: %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
+// CHECK: %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
+// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK: return %[[T11]] : vector<2x3xi1>
+
+func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
+ %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
+ return %0 : vector<2x3xi1>
+}
diff --git a/mlir/test/Target/vector-to-llvm-ir.mlir b/mlir/test/Target/vector-to-llvm-ir.mlir
index 354f70f3f4b9..4ede6ca2a5df 100644
--- a/mlir/test/Target/vector-to-llvm-ir.mlir
+++ b/mlir/test/Target/vector-to-llvm-ir.mlir
@@ -21,3 +21,11 @@ func @genbool_3d() -> vector<2x3x4xi1> {
// CHECK-LABEL: @genbool_3d()
// CHECK-NEXT: ret [2 x [3 x <4 x i1>]] {{\[+}}3 x <4 x i1>] [<4 x i1> <i1 true, i1 true, i1 true, i1 false>, <4 x i1> zeroinitializer, <4 x i1> zeroinitializer], [3 x <4 x i1>] zeroinitializer]
// note: awkward syntax to match [[
+
+// A vector.create_mask whose bound operand is the constant 0 must reach
+// LLVM IR as an all-false mask (zeroinitializer).
+func @genbool_1d_var_but_constant() -> vector<8xi1> {
+ %i = constant 0 : index
+ %v = vector.create_mask %i : vector<8xi1>
+ return %v : vector<8xi1>
+}
+// CHECK-LABEL: @genbool_1d_var_but_constant()
+// CHECK-NEXT: ret <8 x i1> zeroinitializer
More information about the Mlir-commits
mailing list