[Mlir-commits] [mlir] ccef726 - [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Sep 11 09:48:54 PDT 2023
Author: Benjamin Maxwell
Date: 2023-09-11T16:47:51Z
New Revision: ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6
URL: https://github.com/llvm/llvm-project/commit/ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6
DIFF: https://github.com/llvm/llvm-project/commit/ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6.diff
LOG: [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)
This is a follow-on to D158753, and allows the lowering of a
transfer read/write of n-D vectors with a single trailing scalable dimension
to primitive vector ops.
The final conversion to LLVM depends on D158517 and D158752; without
these patches, type conversion will fail (or an assert is hit in the LLVM
backend) if the final IR contains an array of scalable vectors.
This patch also adds `transform.apply_patterns.vector.lower_create_mask`,
which allows the lowering of vector.create_mask/constant_mask to be
tested independently of --convert-vector-to-llvm.
Reviewed By: c-rhodes, awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D159482
Added:
mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..9e718a0c80bbf3b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -122,6 +122,17 @@ def ApplyLowerContractionPatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyLowerCreateMaskPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.lower_create_mask",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector create_mask-like operations should be lowered to
+ finer-grained vector primitives.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerMasksPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_masks",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..b388deaa46a7917 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -64,6 +64,11 @@ void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
vector::populateVectorReductionToContractPatterns(patterns);
}
+void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorMaskOpLoweringPatterns(patterns);
+}
+
void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransferDropUnitDimsPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 796bbab38dcbf68..9a828ec0b845e4a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -58,13 +58,15 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
return rewriter.notifyMatchFailure(
op, "0-D and 1-D vectors are handled separately");
+ if (dstType.getScalableDims().front())
+ return rewriter.notifyMatchFailure(
+ op, "Cannot unroll leading scalable dim in dstType");
+
auto loc = op.getLoc();
- auto eltType = dstType.getElementType();
int64_t dim = dstType.getDimSize(0);
Value idx = op.getOperand(0);
- VectorType lowType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
+ VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::CreateMaskOp>(
loc, lowType, op.getOperands().drop_front());
Value falseVal = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 68160bcf59e6678..bc3e47a71b43097 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -434,7 +434,7 @@ struct TransferReadToVectorLoadLowering
vectorShape.end());
for (unsigned i : broadcastedDims)
unbroadcastedVectorShape[i] = 1;
- VectorType unbroadcastedVectorType = VectorType::get(
+ VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
unbroadcastedVectorShape, read.getVectorType().getElementType());
// `vector.load` supports vector types as memref's elements only when the
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 514594240d22a1b..a1c9aa6edeaa46b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1743,6 +1743,28 @@ func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5x
// -----
+// CHECK-LABEL: func @transfer_read_1d_scalable_mask
+// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: return %[[r]] : vector<[4]xf32>
+func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
+ return %vec : vector<[4]xf32>
+}
+
+// -----
+// CHECK-LABEL: func @transfer_write_1d_scalable_mask
+// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
+func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
+ return
+}
+
+// -----
+
func.func @genbool_0d_f() -> vector<i1> {
%0 = vector.constant_mask [0] : vector<i1>
return %0 : vector<i1>
diff --git a/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
new file mode 100644
index 000000000000000..138e647c751ab8f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
+
+// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable(
+// CHECK-SAME: %[[arg:.*]]: index) -> vector<3x[4]xi1> {
+// CHECK-NEXT: %[[zero_mask_1d:.*]] = arith.constant dense<false> : vector<[4]xi1>
+// CHECK-NEXT: %[[zero_mask_2d:.*]] = arith.constant dense<false> : vector<3x[4]xi1>
+// CHECK-NEXT: %[[create_mask_1d:.*]] = vector.create_mask %[[arg]] : vector<[4]xi1>
+// CHECK-NEXT: %[[res_0:.*]] = vector.insert %[[create_mask_1d]], %[[zero_mask_2d]] [0] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: %[[res_1:.*]] = vector.insert %[[create_mask_1d]], %[[res_0]] [1] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: %[[res_2:.*]] = vector.insert %[[zero_mask_1d]], %[[res_1]] [2] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: return %[[res_2]] : vector<3x[4]xi1>
+func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<3x[4]xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %a : vector<3x[4]xi1>
+ return %mask : vector<3x[4]xi1>
+}
+
+// -----
+
+/// The following cannot be lowered as the current lowering requires unrolling
+/// the leading dim.
+
+// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable(
+// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> {
+// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1>
+func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> {
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %a, %c1 : vector<[4]x4xi1>
+ return %mask : vector<[4]x4xi1>
+}
+
+transform.sequence failures(suppress) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_create_mask
+ } : !transform.any_op
+}
More information about the Mlir-commits
mailing list