[Mlir-commits] [mlir] f2b89c7 - [mlir][Vector] Use create_mask in transfer mask materializations

Javier Setoain llvmlistbot at llvm.org
Tue Mar 8 01:09:39 PST 2022


Author: Javier Setoain
Date: 2022-03-08T09:02:50Z
New Revision: f2b89c7ae083d8c99f9efc7cb90f5d3b63048e89

URL: https://github.com/llvm/llvm-project/commit/f2b89c7ae083d8c99f9efc7cb90f5d3b63048e89
DIFF: https://github.com/llvm/llvm-project/commit/f2b89c7ae083d8c99f9efc7cb90f5d3b63048e89.diff

LOG: [mlir][Vector] Use create_mask in transfer mask materializations

Currently, the transfer mask is materialized by generating the vector
comparison: [offset + 0, .., offset + length - 1] < [dim, .., dim]

A better alternative is to materialize the transfer mask by using the
operation: `vector.create_mask (dim - offset)`, which will generate
simpler code and compose better with scalable vectors.

Differential Revision: https://reviews.llvm.org/D120487

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ee6429d0abd77..2b22412d6fc36 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2259,22 +2259,21 @@ struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
     Location loc = xferOp->getLoc();
     VectorType vtp = xferOp.getVectorType();
 
-    // * Create a vector with linear indices [ 0 .. vector_length - 1 ].
-    // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-    // * Let dim the memref dimension, compute the vector comparison mask
-    //   (in-bounds mask):
-    //   [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
+    // Create the in-bounds mask with all elements between [0 .. dim - offset)
+    // set and [dim - offset .. vector_length) unset.
     //
     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
     //       dimensions here.
-    unsigned vecWidth = vtp.getNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
     Value dim =
         vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
-    Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations,
-                                       vecWidth, dim, &off);
-
+    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
+    Value mask = rewriter.create<vector::CreateMaskOp>(
+        loc,
+        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
+                        vtp.getNumScalableDims()),
+        b);
     if (xferOp.mask()) {
       // Intersect the in-bounds with the mask specified as an op parameter.
       mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 6599323b573d5..7ed8f96789bb1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -25,16 +25,26 @@ func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
 }
 
 // CMP32-LABEL: @transfer_read_1d
+// CMP32: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
+// CMP32: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>
+// CMP32: %[[S:.*]] = arith.subi %[[D]], %[[OFF]] : index
 // CMP32: %[[C:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
-// CMP32: %[[A:.*]] = arith.addi %{{.*}}, %[[C]] : vector<16xi32>
-// CMP32: %[[M:.*]] = arith.cmpi slt, %[[A]], %{{.*}} : vector<16xi32>
+// CMP32: %[[B:.*]] = arith.index_cast %[[S]] : index to i32
+// CMP32: %[[B0:.*]] = llvm.insertelement %[[B]], %{{.*}} : vector<16xi32>
+// CMP32: %[[BV:.*]] = llvm.shufflevector %[[B0]], {{.*}} : vector<16xi32>, vector<16xi32>
+// CMP32: %[[M:.*]] = arith.cmpi slt, %[[C]], %[[BV]] : vector<16xi32>
 // CMP32: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
 // CMP32: return %[[L]] : vector<16xf32>
 
-// CMP64-LABEL: @transfer_read_1d
+// CMP64-LABEL: @transfer_read_1d(
+// CMP64: %[[MEM:.*]]: memref<?xf32>, %[[OFF:.*]]: index) -> vector<16xf32> {
+// CMP64: %[[D:.*]] = memref.dim %[[MEM]], %{{.*}} : memref<?xf32>
+// CMP64: %[[S:.*]] = arith.subi %[[D]], %[[OFF]] : index
 // CMP64: %[[C:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi64>
-// CMP64: %[[A:.*]] = arith.addi %{{.*}}, %[[C]] : vector<16xi64>
-// CMP64: %[[M:.*]] = arith.cmpi slt, %[[A]], %{{.*}} : vector<16xi64>
+// CMP64: %[[B:.*]] = arith.index_cast %[[S]] : index to i64
+// CMP64: %[[B0:.*]] = llvm.insertelement %[[B]], %{{.*}} : vector<16xi64>
+// CMP64: %[[BV:.*]] = llvm.shufflevector %[[B0]], {{.*}} : vector<16xi64>, vector<16xi64>
+// CMP64: %[[M:.*]] = arith.cmpi slt, %[[C]], %[[BV]] : vector<16xi64>
 // CMP64: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %[[M]], %{{.*}}
 // CMP64: return %[[L]] : vector<16xf32>
 

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d703d5bdb0b1d..3dcbd3ae475e2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1245,34 +1245,33 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
   return %f: vector<17xf32>
 }
 // CHECK-LABEL: func @transfer_read_1d
-//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//       CHECK: %[[c7:.*]] = arith.constant 7.0
+//  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
+//  CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
+//       CHECK: %[[C7:.*]] = arith.constant 7.0
+//
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
 //       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
+//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
 //
-// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex:.*]] = arith.constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 //  CHECK-SAME: vector<17xi32>
 //
-// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-//       CHECK: %[[otrunc:.*]] = arith.index_cast %[[BASE]] : index to i32
-//       CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[otrunc]]
-//       CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
-//       CHECK: %[[offsetVec2:.*]] = arith.addi %[[offsetVec]], %[[linearIndex]] : vector<17xi32>
-//
-// 3. Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//       CHECK: %[[dtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-//       CHECK: %[[dimVecInsert:.*]] = llvm.insertelement %[[dtrunc]]
-//       CHECK: %[[dimVec:.*]] = llvm.shufflevector %[[dimVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[offsetVec2]], %[[dimVec]] : vector<17xi32>
+// 3. Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+//  CHECK-SAME: : vector<17xi32>
 //
 // 4. Create pass-through vector.
 //       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
 //
 // 5. Bitcast to vector form.
-//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+//       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 //  CHECK-SAME: (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 //       CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
 //  CHECK-SAME: !llvm.ptr<f32> to !llvm.ptr<vector<17xf32>>
@@ -1280,21 +1279,24 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 // 6. Rewrite as a masked read.
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
 //  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 4 : i32} :
-//  CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
 //
-// 1. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
+//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+//       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
+//
+// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex_b:.*]] = arith.constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
 //  CHECK-SAME: vector<17xi32>
 //
-// 2. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-//       CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
-//       CHECK: arith.addi
-//
-// 3. Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//       CHECK: llvm.shufflevector %{{.*}} : vector<17xi32>
-//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, {{.*}} : vector<17xi32>
+// 3. Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+//       CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]] : index to i32
+//       CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
+//       CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
+//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
+//  CHECK-SAME: %[[boundVect_b]] : vector<17xi32>
 //
 // 4. Bitcast to vector form.
 //       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
@@ -1344,16 +1346,20 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
 //       CHECK: %[[c1:.*]] = arith.constant 1 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
-// Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
-//       CHECK: %[[trunc:.*]] = arith.index_cast %[[BASE_1]] : index to i32
-//       CHECK: %[[offsetVecInsert:.*]] = llvm.insertelement %[[trunc]]
-//       CHECK: %[[offsetVec:.*]] = llvm.shufflevector %[[offsetVecInsert]]
+// Compute the in-bound index (dim - offset)
+//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+//       CHECK: %[[linearIndex:.*]] = arith.constant dense
+//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME: vector<17xi32>
 //
-// Let dim the memref dimension, compute the vector comparison mask:
-//    [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
-//       CHECK: %[[dimtrunc:.*]] = arith.index_cast %[[DIM]] : index to i32
-//       CHECK: %[[dimtruncInsert:.*]] = llvm.insertelement %[[dimtrunc]]
-//       CHECK: llvm.shufflevector %[[dimtruncInsert]]
+// Create bound vector to compute in-bound mask:
+//    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
+//       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to i32
+//       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
+//       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
+//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
 
 // -----
 


        


More information about the Mlir-commits mailing list