[Mlir-commits] [mlir] 408553d - [mlir][Vector] Support 0-D vectors in `CreateMaskOp`
Nicolas Vasilache
llvmlistbot at llvm.org
Sun Dec 12 05:35:58 PST 2021
Author: Nicolas Vasilache
Date: 2021-12-12T13:32:29Z
New Revision: 408553dd96792929b2468bcd2d8e4764ae7c2c9e
URL: https://github.com/llvm/llvm-project/commit/408553dd96792929b2468bcd2d8e4764ae7c2c9e
DIFF: https://github.com/llvm/llvm-project/commit/408553dd96792929b2468bcd2d8e4764ae7c2c9e.diff
LOG: [mlir][Vector] Support 0-D vectors in `CreateMaskOp`
The 0-D case gets lowered in almost the same way that the 1-D case does
in VectorCreateMaskOpConversion. I also had to slightly update the
verifier for the op to always require exactly 1 operand in the 0-D case.
Depends On D115220
Reviewed by: ftynse
Differential revision: https://reviews.llvm.org/D115221
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 14afb47004f43..a05a377731665 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -2149,7 +2149,8 @@ def Vector_ConstantMaskOp :
def Vector_CreateMaskOp :
Vector_Op<"create_mask", [NoSideEffect]>,
- Arguments<(ins Variadic<Index>:$operands)>, Results<(outs VectorOf<[I1]>)> {
+ Arguments<(ins Variadic<Index>:$operands)>,
+ Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
let description = [{
Creates and returns a vector mask where elements of the result vector
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 9768483231c9d..27a0fc6d9dcbe 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3968,11 +3968,17 @@ static LogicalResult verify(ConstantMaskOp &op) {
//===----------------------------------------------------------------------===//
static LogicalResult verify(CreateMaskOp op) {
+ auto vectorType = op.getResult().getType().cast<VectorType>();
// Verify that an operand was specified for each result vector each dimension.
- if (op.getNumOperands() !=
- op.getResult().getType().cast<VectorType>().getRank())
+ if (vectorType.getRank() == 0) {
+ if (op->getNumOperands() != 1)
+ return op.emitOpError(
+ "must specify exactly one operand for 0-D create_mask");
+ } else if (op.getNumOperands() !=
+ op.getResult().getType().cast<VectorType>().getRank()) {
return op.emitOpError(
"must specify an operand for each result vector dimension");
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 30335b70f4a22..008db8139df10 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -677,16 +677,17 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
- auto loc = op.getLoc();
auto dstType = op.getResult().getType().cast<VectorType>();
+ int64_t rank = dstType.getRank();
+ if (rank <= 1)
+ return rewriter.notifyMatchFailure(
+ op, "0-D and 1-D vectors are handled separately");
+
+ auto loc = op.getLoc();
auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
- int64_t rank = dstType.getRank();
Value idx = op.getOperand(0);
- if (rank == 1)
- return failure(); // leave for lowering
-
VectorType lowType =
VectorType::get(dstType.getShape().drop_front(), eltType);
Value trueVal = rewriter.create<vector::CreateMaskOp>(
@@ -2717,6 +2718,8 @@ static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
+// If `dim == 0` then the result will be a 0-D vector.
+//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
// much more compact, IR for this operation, but LLVM eventually
// generates more elaborate instructions for this intrinsic since it
@@ -2728,19 +2731,23 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// If we can assume all indices fit in 32-bit, we perform the vector
// comparison in 32-bit to get a higher degree of SIMD parallelism.
// Otherwise we perform the vector comparison using 64-bit indices.
- Value indices;
- Type idxType;
- if (indexOptimizations) {
- indices = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32VectorAttr(
- llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))));
- idxType = rewriter.getI32Type();
+ Type idxType =
+ indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
+ DenseIntElementsAttr indicesAttr;
+ if (dim == 0 && indexOptimizations) {
+ indicesAttr = DenseIntElementsAttr::get(
+ VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
+ } else if (dim == 0) {
+ indicesAttr = DenseIntElementsAttr::get(
+ VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
+ } else if (indexOptimizations) {
+ indicesAttr = rewriter.getI32VectorAttr(
+ llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
} else {
- indices = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI64VectorAttr(
- llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))));
- idxType = rewriter.getI64Type();
+ indicesAttr = rewriter.getI64VectorAttr(
+ llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
}
+ Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
// Add in an offset if requested.
if (off) {
Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
@@ -2806,7 +2813,7 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
const bool indexOptimizations;
};
-/// Conversion pattern for a vector.create_mask (1-D only).
+/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
class VectorCreateMaskOpConversion
: public OpRewritePattern<vector::CreateMaskOp> {
public:
@@ -2819,13 +2826,13 @@ class VectorCreateMaskOpConversion
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
int64_t rank = dstType.getRank();
- if (rank == 1) {
- rewriter.replaceOp(
- op, buildVectorComparison(rewriter, op, indexOptimizations,
- dstType.getDimSize(0), op.getOperand(0)));
- return success();
- }
- return failure();
+ if (rank > 1)
+ return failure();
+ rewriter.replaceOp(
+ op, buildVectorComparison(rewriter, op, indexOptimizations,
+ rank == 0 ? 0 : dstType.getDimSize(0),
+ op.getOperand(0)));
+ return success();
}
private:
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ee3c7567dd3aa..ccb75b8a606a2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1442,6 +1442,35 @@ func @genbool_2d() -> vector<4x4xi1> {
// -----
+func @create_mask_0d(%a : index) -> vector<i1> {
+ %v = vector.create_mask %a : vector<i1>
+ return %v: vector<i1>
+}
+
+// CHECK-LABEL: func @create_mask_0d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = arith.constant dense<0> : vector<i32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<i32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<i32>
+// CHECK: return %[[result]] : vector<i1>
+// -----
+
+func @create_mask_1d(%a : index) -> vector<4xi1> {
+ %v = vector.create_mask %a : vector<4xi1>
+ return %v: vector<4xi1>
+}
+
+// CHECK-LABEL: func @create_mask_1d
+// CHECK-SAME: %[[arg:.*]]: index
+// CHECK: %[[indices:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+// CHECK: %[[arg_i32:.*]] = arith.index_cast %[[arg]] : index to i32
+// CHECK: %[[bounds:.*]] = splat %[[arg_i32]] : vector<4xi32>
+// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
+// CHECK: return %[[result]] : vector<4xi1>
+
+// -----
+
func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
: vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 195e720d2ad99..a384d42ef6112 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -874,6 +874,24 @@ func @contraction(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<
// -----
+func @create_mask_0d_no_operands() {
+ %c1 = arith.constant 1 : index
+ // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+ %0 = vector.create_mask : vector<i1>
+}
+
+// -----
+
+func @create_mask_0d_many_operands() {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ // expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}
+ %0 = vector.create_mask %c1, %c2, %c3 : vector<i1>
+}
+
+// -----
+
func @create_mask() {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index 67a4257fa35dc..572280f570737 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -93,6 +93,18 @@ func @arith_cmpi_0d(%smaller : vector<i32>, %bigger : vector<i32>) {
return
}
+func @create_mask_0d(%zero : index, %one : index) {
+ %zero_mask = vector.create_mask %zero : vector<i1>
+ // CHECK: ( 0 )
+ vector.print %zero_mask : vector<i1>
+
+ %one_mask = vector.create_mask %one : vector<i1>
+ // CHECK: ( 1 )
+ vector.print %one_mask : vector<i1>
+
+ return
+}
+
func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -115,5 +127,9 @@ func @entry() {
%bigger = arith.constant dense<4242> : vector<i32>
call @arith_cmpi_0d(%smaller, %bigger) : (vector<i32>, vector<i32>) -> ()
+ %zero_idx = arith.constant 0 : index
+ %one_idx = arith.constant 1 : index
+ call @create_mask_0d(%zero_idx, %one_idx) : (index, index) -> ()
+
return
}
More information about the Mlir-commits
mailing list