[Mlir-commits] [mlir] 7386364 - Revert "[MLIR] Update Vector To LLVM conversion to be aware of assume_alignment"

Stephen Neuendorffer llvmlistbot at llvm.org
Tue Nov 30 15:18:34 PST 2021


Author: Stephen Neuendorffer
Date: 2021-11-30T15:18:22-08:00
New Revision: 73863648892ee7063c7fd4e658d7614fd721504a

URL: https://github.com/llvm/llvm-project/commit/73863648892ee7063c7fd4e658d7614fd721504a
DIFF: https://github.com/llvm/llvm-project/commit/73863648892ee7063c7fd4e658d7614fd721504a.diff

LOG: Revert "[MLIR] Update Vector To LLVM conversion to be aware of assume_alignment"

This reverts commit 29a50c5864ddab283c1ff38694fb5926ce37b39a.

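After LLVM lowering, the original patch incorrectly propagated the alignment
asserted by memref.assume_alignment on the memref's base pointer across an
unconstrained GEP operation. This is only correct for some index offsets in
the GEP: for example, if the base of a memref<?xf32> is assumed to be 32-byte
aligned, an access at element offset 1 (byte offset 4) is only guaranteed to
be 4-byte aligned, not 32-byte aligned. It seems that the best approach is, in
fact, to rely on LLVM to propagate the information from the llvm.assume() to
the users of the pointer.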

Thanks to Thomas Raoux for catching this.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5ea34d03bec79..bc42922a44858 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -84,30 +84,6 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   return success();
 }
 
-// Return the minimal alignment value that satisfies all the AssumeAlignment
-// uses of `value`. If no such uses exist, return 1.
-static unsigned getAssumedAlignment(Value value) {
-  unsigned align = 1;
-  for (auto &u : value.getUses()) {
-    Operation *owner = u.getOwner();
-    if (auto op = dyn_cast<memref::AssumeAlignmentOp>(owner))
-      align = mlir::lcm(align, op.alignment());
-  }
-  return align;
-}
-
-// Helper that returns data layout alignment of a memref associated with a
-// load, store, scatter, or gather op, including additional information from
-// assume_alignment calls on the source of the transfer
-template <class OpAdaptor>
-LogicalResult getMemRefOpAlignment(LLVMTypeConverter &typeConverter,
-                                   OpAdaptor op, unsigned &align) {
-  if (failed(getMemRefAlignment(typeConverter, op.getMemRefType(), align)))
-    return failure();
-  align = std::max(align, getAssumedAlignment(op.base()));
-  return success();
-}
-
 // Add an index vector component to a base pointer. This almost always succeeds
 // unless the last stride is non-unit or the memory space is not zero.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
@@ -246,8 +222,7 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*this->getTypeConverter(), loadOrStoreOp,
-                                    align)))
+    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
       return failure();
 
     // Resolve address.
@@ -276,7 +251,7 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*getTypeConverter(), gather, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.
@@ -310,7 +285,7 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefOpAlignment(*getTypeConverter(), scatter, align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
     // Resolve address.

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 42a264a8aa972..ce81c4e36bb63 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1293,26 +1293,6 @@ func @transfer_read_index_1d(%A : memref<?xindex>, %base: index) -> vector<17xin
 
 // -----
 
-func @transfer_read_1d_aligned(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
-  memref.assume_alignment %A, 32 : memref<?xf32>
-  %f7 = arith.constant 7.0: f32
-  %f = vector.transfer_read %A[%base], %f7
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    memref<?xf32>, vector<17xf32>
-  vector.transfer_write %f, %A[%base]
-      {permutation_map = affine_map<(d0) -> (d0)>} :
-    vector<17xf32>, memref<?xf32>
-  return %f: vector<17xf32>
-}
-//       CHECK: llvm.intr.masked.load
-//  CHECK-SAME: {alignment = 32 : i32}
-//  CHECK-SAME: (!llvm.ptr<vector<17xf32>>, vector<17xi1>, vector<17xf32>) -> vector<17xf32>
-//       CHECK: llvm.intr.masked.store
-//  CHECK-SAME: {alignment = 32 : i32}
-//  CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr<vector<17xf32>>
-
-// -----
-
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {
   %f7 = arith.constant 7.0: f32
   %f = vector.transfer_read %A[%base0, %base1], %f7
@@ -1485,22 +1465,6 @@ func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : ind
 
 // -----
 
-func @vector_load_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
-// CHECK-LABEL: func @vector_load_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<8xf32>>
-// CHECK: llvm.load %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<8xf32>>
-
-// -----
-
 func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) {
   %val = arith.constant dense<11.0> : vector<4xf32>
   vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
@@ -1527,23 +1491,6 @@ func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : in
 
 // -----
 
-func @vector_store_op_aligned(%memref : memref<200x100xf32>, %i : index, %j : index) {
-  memref.assume_alignment %memref, 32 : memref<200x100xf32>
-  %val = arith.constant dense<11.0> : vector<4xf32>
-  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
-  return
-}
-
-// CHECK-LABEL: func @vector_store_op_aligned
-// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64
-// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]]  : i64
-// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}}  : i64
-// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr<f32> to !llvm.ptr<vector<4xf32>>
-// CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 32 : i64} : !llvm.ptr<vector<4xf32>>
-
-// -----
-
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
@@ -1621,20 +1568,6 @@ func @gather_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: vec
 
 // -----
 
-func @gather_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
-  return %1 : vector<3xf32>
-}
-
-// CHECK-LABEL: func @gather_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.vec<3 x ptr<f32>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: return %[[G]] : vector<3xf32>
-
-// -----
-
 func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -1673,19 +1606,6 @@ func @scatter_op_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: ve
 
 // -----
 
-func @scatter_op_aligned(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  memref.assume_alignment %arg0, 32 : memref<?xf32>
-  %0 = arith.constant 0: index
-  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
-  return
-}
-
-// CHECK-LABEL: func @scatter_op_aligned
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<f32>, vector<3xi32>) -> !llvm.vec<3 x ptr<f32>>
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr<f32>>
-
-// -----
-
 func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) {
   %0 = arith.constant 3 : index
   vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32>


        


More information about the Mlir-commits mailing list