[Mlir-commits] [mlir] 29a50c5 - [MLIR] Update Vector To LLVM conversion to be aware of assume_alignment

Stephen Neuendorffer llvmlistbot at llvm.org
Wed May 19 10:50:58 PDT 2021


Author: Stephen Neuendorffer
Date: 2021-05-19T10:50:48-07:00
New Revision: 29a50c5864ddab283c1ff38694fb5926ce37b39a

URL: https://github.com/llvm/llvm-project/commit/29a50c5864ddab283c1ff38694fb5926ce37b39a
DIFF: https://github.com/llvm/llvm-project/commit/29a50c5864ddab283c1ff38694fb5926ce37b39a.diff

LOG: [MLIR] Update Vector To LLVM conversion to be aware of assume_alignment

vector.transfer_read and vector.transfer_write operations are converted
to llvm intrinsics with specific alignment information, however there
doesn't seem to be a way in llvm to take information from llvm.assume
intrinsics and change this alignment information.  In any
event, due to the structure of the llvm.assume intrinsic, applying
this information at the llvm level is more cumbersome.  Instead, let's
generate the masked vector load and store intrinsics with the right
alignment information from MLIR in the first place.  Since
we're bothering to do this, let's just emit the proper alignment for
loads, stores, scatter, and gather ops too.

Differential Revision: https://reviews.llvm.org/D100444

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5665591583854..f3909b3e85c6b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeTranslation.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -119,6 +120,42 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
+// Return the minimal alignment value that satisfies all the AssumeAlignment
+// uses of `value`. If no such uses exist, return 1.
+static unsigned getAssumedAlignment(Value value) {
+  unsigned align = 1;
+  for (auto &u : value.getUses()) {
+    Operation *owner = u.getOwner();
+    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
+      align = mlir::lcm(align, op.alignment());
+  }
+  return align;
+}
+// Helper that returns data layout alignment of a memref associated with a
+// transfer op, including additional information from assume_alignment calls
+// on the source of the transfer
+LogicalResult getTransferOpAlignment(LLVMTypeConverter &typeConverter,
+                                     VectorTransferOpInterface xfer,
+                                     unsigned &align) {
+  if (failed(getMemRefAlignment(
+          typeConverter, xfer.getShapedType().cast<MemRefType>(), align)))
+    return failure();
+  align = std::max(align, getAssumedAlignment(xfer.source()));
+  return success();
+}
+
+// Helper that returns data layout alignment of a memref associated with a
+// load, store, scatter, or gather op, including additional information from
+// assume_alignment calls on the source of the transfer
+template <class OpAdaptor>
+LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
+                                   OpAdaptor op, unsigned &align) {
+  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
+    return failure();
+  align = std::max(align, getAssumedAlignment(op.base()));
+  return success();
+}
+
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -151,8 +188,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
   rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
@@ -171,10 +207,8 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
   Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());
 
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
-
   rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
       xferOp, vecTy, dataPtr, mask, ValueRange{fill},
       rewriter.getI32IntegerAttr(align));
@@ -187,8 +221,7 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
@@ -202,8 +235,7 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
                             TransferWriteOp xferOp, ArrayRef<Value> operands,
                             Value dataPtr, Value mask) {
   unsigned align;
-  if (failed(getMemRefAlignment(
-          typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
+  if (failed(getTransferOpAlignment(typeConverter, xferOp, align)))
     return failure();
 
   auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary());
@@ -337,7 +369,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
+                                    align)))
       return failure();
 
     // Resolve address.
@@ -367,7 +400,7 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
       return failure();
 
     // Resolve address.
@@ -402,7 +435,7 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
+    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
       return failure();
 
     // Resolve address.

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7bb3ab0e27aea..35aea5f841910 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1295,6 +1295,26 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 
 // -----
 
+func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
+  memref.assume_alignment %A, 32 : memref<?xf32>
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32>, vector<17xf32>
+  vector.transfer_write %f, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<17xf32>, memref<?xf32>
+  return %f: vector<17xf32>
+}
+//       CHECK: llvm.intr.masked.load
+//  CHECK-SAME: {alignment = 32 : i32}
+//  CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
+//       CHECK: llvm.intr.masked.store
+//  CHECK-SAME: {alignment = 32 : i32}
+//  CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
+
+// -----
+
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1487,6 +1507,22 @@ func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : ind
 
 // -----
 
+func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  memref.assume_alignment %memref, 32 : memref<200x100xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @vector_load_op_aligned
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
+// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<8xf32>>
+
+// -----
+
 func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = constant dense<11.0> : vector<4xf32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1513,6 +1549,23 @@ func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : in
 
 // -----
 
+func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  memref.assume_alignment %memref, 32 : memref<200x100xf32>
+  %val = constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_aligned
+// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
+// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
+// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
+// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<4xf32>>
+
+// -----
+
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1590,6 +1643,20 @@ func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vec
 
 // -----
 
+func @gather_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  memref.assume_alignment %arg0, 32 : memref<?xf32>
+  %0 = constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_op_aligned
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: return %[[G]] : vector<3xf32>
+
+// -----
+
 func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1628,6 +1695,19 @@ func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: ve
 
 // -----
 
+func @scatter_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  memref.assume_alignment %arg0, 32 : memref<?xf32>
+  %0 = constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_op_aligned
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
+
+// -----
+
 func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
   %0 = constant 3 : index
   vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>


        


More information about the Mlir-commits mailing list