[Mlir-commits] [mlir] 55d09df - [mlir] [VectorOps] Improve vector.create_mask lowering
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 23 14:34:56 PDT 2020
Author: aartbik
Date: 2020-06-23T14:33:41-07:00
New Revision: 55d09dfc7b147dbb74ae62173f4d5b078b19e328
URL: https://github.com/llvm/llvm-project/commit/55d09dfc7b147dbb74ae62173f4d5b078b19e328
DIFF: https://github.com/llvm/llvm-project/commit/55d09dfc7b147dbb74ae62173f4d5b078b19e328.diff
LOG: [mlir] [VectorOps] Improve vector.create_mask lowering
Use vector compares for the 1-D case. This approach scales much better
than generating insertion operations, and exposes SIMD directly to the backend.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D82402
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index a9207230278f..effae86e4597 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1328,6 +1328,8 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
int64_t trueDim = dimSizes[0].cast<IntegerAttr>().getInt();
if (rank == 1) {
+ // Express constant 1-D case in explicit vector form:
+ // [T,..,T,F,..,F].
SmallVector<bool, 4> values(dstType.getDimSize(0));
for (int64_t d = 0; d < trueDim; d++)
values[d] = true;
@@ -1364,8 +1366,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
/// %1 = select %0, %l, %zeroes |
/// %r = vector.insert %1, %pr [i] | d-times
/// %x = ....
-/// When rank == 1, the selection operator is not needed,
-/// and we can assign the true/false value right away.
+/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;
@@ -1375,30 +1376,41 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
auto loc = op.getLoc();
auto dstType = op.getResult().getType().cast<VectorType>();
auto eltType = dstType.getElementType();
+ int64_t dim = dstType.getDimSize(0);
int64_t rank = dstType.getRank();
Value idx = op.getOperand(0);
- Value trueVal;
- Value falseVal;
- if (rank > 1) {
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
- trueVal = rewriter.create<vector::CreateMaskOp>(
- loc, lowType, op.getOperands().drop_front());
- falseVal = rewriter.create<ConstantOp>(loc, lowType,
- rewriter.getZeroAttr(lowType));
+ if (rank == 1) {
+ // Express dynamic 1-D case in explicit vector form:
+ // mask = [0,1,..,n-1] < [a,a,..,a]
+ SmallVector<int64_t, 4> values(dim);
+ for (int64_t d = 0; d < dim; d++)
+ values[d] = d;
+ Value indices =
+ rewriter.create<ConstantOp>(loc, rewriter.getI64VectorAttr(values));
+ Value bound =
+ rewriter.create<IndexCastOp>(loc, rewriter.getI64Type(), idx);
+ Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
+ rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::slt, indices,
+ bounds);
+ return success();
}
+ VectorType lowType =
+ VectorType::get(dstType.getShape().drop_front(), eltType);
+ Value trueVal = rewriter.create<vector::CreateMaskOp>(
+ loc, lowType, op.getOperands().drop_front());
+ Value falseVal = rewriter.create<ConstantOp>(loc, lowType,
+ rewriter.getZeroAttr(lowType));
Value result = rewriter.create<ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
- for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; d++) {
+ for (int64_t d = 0; d < dim; d++) {
Value bnd = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(d));
Value val = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, bnd, idx);
- if (rank > 1)
- val = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
+ Value sel = rewriter.create<SelectOp>(loc, val, trueVal, falseVal);
auto pos = rewriter.getI64ArrayAttr(d);
result =
- rewriter.create<vector::InsertOp>(loc, dstType, val, result, pos);
+ rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index d4c68ab2cebd..a07675515a9b 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -710,18 +710,12 @@ func @genbool_3d() -> vector<2x3x4xi1> {
}
// CHECK-LABEL: func @genbool_var_1d
-// CHECK-SAME: %[[A:.*0]]: index
-// CHECK-DAG: %[[VF:.*]] = constant dense<false> : vector<3xi1>
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[VF]] [0] : i1 into vector<3xi1>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
-// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[A]] : index
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
-// CHECK: return %[[T5]] : vector<3xi1>
+// CHECK-SAME: %[[A:.*]]: index
+// CHECK: %[[C1:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
+// CHECK: %[[T0:.*]] = index_cast %[[A]] : index to i64
+// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[T1]] : vector<3xi64>
+// CHECK: return %[[T2]] : vector<3xi1>
func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
%0 = vector.create_mask %arg0 : vector<3xi1>
@@ -731,24 +725,21 @@ func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
// CHECK-LABEL: func @genbool_var_2d
// CHECK-SAME: %[[A:.*0]]: index
// CHECK-SAME: %[[B:.*1]]: index
-// CHECK-DAG: %[[Z1:.*]] = constant dense<false> : vector<3xi1>
-// CHECK-DAG: %[[Z2:.*]] = constant dense<false> : vector<2x3xi1>
-// CHECK-DAG: %[[C0:.*]] = constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = constant 2 : index
-// CHECK: %[[T0:.*]] = cmpi "slt", %[[C0]], %[[B]] : index
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z1]] [0] : i1 into vector<3xi1>
-// CHECK: %[[T2:.*]] = cmpi "slt", %[[C1]], %[[B]] : index
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : i1 into vector<3xi1>
-// CHECK: %[[T4:.*]] = cmpi "slt", %[[C2]], %[[B]] : index
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : i1 into vector<3xi1>
-// CHECK: %[[T6:.*]] = cmpi "slt", %[[C0]], %[[A]] : index
-// CHECK: %[[T7:.*]] = select %[[T6]], %[[T5]], %[[Z1]] : vector<3xi1>
-// CHECK: %[[T8:.*]] = vector.insert %7, %[[Z2]] [0] : vector<3xi1> into vector<2x3xi1>
-// CHECK: %[[T9:.*]] = cmpi "slt", %[[C1]], %[[A]] : index
-// CHECK: %[[T10:.*]] = select %[[T9]], %[[T5]], %[[Z1]] : vector<3xi1>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T8]] [1] : vector<3xi1> into vector<2x3xi1>
-// CHECK: return %[[T11]] : vector<2x3xi1>
+// CHECK: %[[CI:.*]] = constant dense<[0, 1, 2]> : vector<3xi64>
+// CHECK: %[[CF:.*]] = constant dense<false> : vector<3xi1>
+// CHECK: %[[C2:.*]] = constant dense<false> : vector<2x3xi1>
+// CHECK: %[[c0:.*]] = constant 0 : index
+// CHECK: %[[c1:.*]] = constant 1 : index
+// CHECK: %[[T0:.*]] = index_cast %[[B]] : index to i64
+// CHECK: %[[T1:.*]] = splat %[[T0]] : vector<3xi64>
+// CHECK: %[[T2:.*]] = cmpi "slt", %[[CI]], %[[T1]] : vector<3xi64>
+// CHECK: %[[T3:.*]] = cmpi "slt", %[[c0]], %[[A]] : index
+// CHECK: %[[T4:.*]] = select %[[T3]], %[[T2]], %[[CF]] : vector<3xi1>
+// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
+// CHECK: %[[T6:.*]] = cmpi "slt", %[[c1]], %[[A]] : index
+// CHECK: %[[T7:.*]] = select %[[T6]], %[[T2]], %[[CF]] : vector<3xi1>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T5]] [1] : vector<3xi1> into vector<2x3xi1>
+// CHECK: return %[[T8]] : vector<2x3xi1>
func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
More information about the Mlir-commits
mailing list