[Mlir-commits] [mlir] ccef726 - [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Sep 11 09:48:54 PDT 2023


Author: Benjamin Maxwell
Date: 2023-09-11T16:47:51Z
New Revision: ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6

URL: https://github.com/llvm/llvm-project/commit/ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6
DIFF: https://github.com/llvm/llvm-project/commit/ccef726d09b1ffadfae6b1d1d986ae2f6d25a6a6.diff

LOG: [mlir][VectorOps] Don't drop scalable dims when lowering transfer_reads/writes (in VectorToLLVM)

This is a follow-on to D158753, and allows the lowering of a
transfer read/write of n-D vectors with a single trailing scalable dimension
to primitive vector ops.

The final conversion to LLVM depends on D158517 and D158752. Without
these patches, type conversion will fail (or an assert is hit in the LLVM
backend) if the final IR contains an array of scalable vectors.

This patch adds `transform.apply_patterns.vector.lower_create_mask`
which allows the lowering of vector.create_mask/constant_mask to be
tested independently of --convert-vector-to-llvm.

Reviewed By: c-rhodes, awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D159482

Added: 
    mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..9e718a0c80bbf3b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -122,6 +122,17 @@ def ApplyLowerContractionPatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyLowerCreateMaskPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_create_mask",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector create_mask-like operations should be lowered to
+    finer-grained vector primitives.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerMasksPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_masks",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..b388deaa46a7917 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -64,6 +64,11 @@ void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
   vector::populateVectorReductionToContractPatterns(patterns);
 }
 
+void transform::ApplyLowerCreateMaskPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorMaskOpLoweringPatterns(patterns);
+}
+
 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorTransferDropUnitDimsPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 796bbab38dcbf68..9a828ec0b845e4a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -58,13 +58,15 @@ class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
       return rewriter.notifyMatchFailure(
           op, "0-D and 1-D vectors are handled separately");
 
+    if (dstType.getScalableDims().front())
+      return rewriter.notifyMatchFailure(
+          op, "Cannot unroll leading scalable dim in dstType");
+
     auto loc = op.getLoc();
-    auto eltType = dstType.getElementType();
     int64_t dim = dstType.getDimSize(0);
     Value idx = op.getOperand(0);
 
-    VectorType lowType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
+    VectorType lowType = VectorType::Builder(dstType).dropDim(0);
     Value trueVal = rewriter.create<vector::CreateMaskOp>(
         loc, lowType, op.getOperands().drop_front());
     Value falseVal = rewriter.create<arith::ConstantOp>(

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 68160bcf59e6678..bc3e47a71b43097 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -434,7 +434,7 @@ struct TransferReadToVectorLoadLowering
                                                   vectorShape.end());
     for (unsigned i : broadcastedDims)
       unbroadcastedVectorShape[i] = 1;
-    VectorType unbroadcastedVectorType = VectorType::get(
+    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
         unbroadcastedVectorShape, read.getVectorType().getElementType());
 
     // `vector.load` supports vector types as memref's elements only when the

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 514594240d22a1b..a1c9aa6edeaa46b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1743,6 +1743,28 @@ func.func @transfer_read_1d_mask(%A : memref<?xf32>, %base : index) -> vector<5x
 
 // -----
 
+// CHECK-LABEL: func @transfer_read_1d_scalable_mask
+// CHECK: %[[passtru:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %[[passtru]] {alignment = 4 : i32} : (!llvm.ptr, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: return %[[r]] : vector<[4]xf32>
+func.func @transfer_read_1d_scalable_mask(%arg0: memref<1x?xf32>, %mask: vector<[4]xi1>) -> vector<[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %arg0[%c0, %c0], %pad, %mask {in_bounds = [true]} : memref<1x?xf32>, vector<[4]xf32>
+  return %vec : vector<[4]xf32>
+}
+
+// -----
+// CHECK-LABEL: func @transfer_write_1d_scalable_mask
+// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.ptr
+func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<[4]xf32>, %mask: vector<[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg0[%c0, %c0], %mask {in_bounds = [true]} : vector<[4]xf32>, memref<1x?xf32>
+  return
+}
+
+// -----
+
 func.func @genbool_0d_f() -> vector<i1> {
   %0 = vector.constant_mask [0] : vector<i1>
   return %0 : vector<i1>

diff  --git a/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
new file mode 100644
index 000000000000000..138e647c751ab8f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
+
+// CHECK-LABEL: func.func @create_mask_2d_trailing_scalable(
+// CHECK-SAME: %[[arg:.*]]: index) -> vector<3x[4]xi1> {
+// CHECK-NEXT: %[[zero_mask_1d:.*]] = arith.constant dense<false> : vector<[4]xi1>
+// CHECK-NEXT: %[[zero_mask_2d:.*]] = arith.constant dense<false> : vector<3x[4]xi1>
+// CHECK-NEXT: %[[create_mask_1d:.*]] = vector.create_mask %[[arg]] : vector<[4]xi1>
+// CHECK-NEXT: %[[res_0:.*]] = vector.insert %[[create_mask_1d]], %[[zero_mask_2d]] [0] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: %[[res_1:.*]] = vector.insert %[[create_mask_1d]], %[[res_0]] [1] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: %[[res_2:.*]] = vector.insert %[[zero_mask_1d]], %[[res_1]] [2] : vector<[4]xi1> into vector<3x[4]xi1>
+// CHECK-NEXT: return %[[res_2]] : vector<3x[4]xi1>
+func.func @create_mask_2d_trailing_scalable(%a: index) -> vector<3x[4]xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %a : vector<3x[4]xi1>
+  return %mask : vector<3x[4]xi1>
+}
+
+// -----
+
+/// The following cannot be lowered as the current lowering requires unrolling
+/// the leading dim.
+
+// CHECK-LABEL: func.func @cannot_create_mask_2d_leading_scalable(
+// CHECK-SAME: %[[arg:.*]]: index) -> vector<[4]x4xi1> {
+// CHECK: %{{.*}} = vector.create_mask %[[arg]], %{{.*}} : vector<[4]x4xi1>
+func.func @cannot_create_mask_2d_leading_scalable(%a: index) -> vector<[4]x4xi1> {
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %a, %c1 : vector<[4]x4xi1>
+  return %mask : vector<[4]x4xi1>
+}
+
+transform.sequence failures(suppress) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_create_mask
+  } : !transform.any_op
+}


        


More information about the Mlir-commits mailing list