[Mlir-commits] [mlir] a23f190 - [mlir][vector] set alignment when lowering transfer_read and transfer_write.

Alex Zinenko llvmlistbot at llvm.org
Thu May 7 02:44:34 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-05-07T11:44:25+02:00
New Revision: a23f190213e16ec0f9075e1a813a046730f73458

URL: https://github.com/llvm/llvm-project/commit/a23f190213e16ec0f9075e1a813a046730f73458
DIFF: https://github.com/llvm/llvm-project/commit/a23f190213e16ec0f9075e1a813a046730f73458.diff

LOG: [mlir][vector] set alignment when lowering transfer_read and transfer_write.

When emitting masked loads and stores, set the alignment from the data layout (the vector type's preferred alignment) instead of hard-coding an alignment of 1.

Differential Revision: https://reviews.llvm.org/D79246

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index dec932173c45..423b93b42e00 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -752,6 +752,19 @@ void replaceTransferOp(ConversionPatternRewriter &rewriter,
                        Operation *op, ArrayRef<Value> operands, Value dataPtr,
                        Value mask);
 
+LogicalResult getLLVMTypeAndAlignment(LLVMTypeConverter &typeConverter,
+                                      Type type, LLVM::LLVMType &llvmType,
+                                      unsigned &align) {
+  auto convertedType = typeConverter.convertType(type);
+  if (!convertedType)
+    return failure();
+
+  llvmType = convertedType.template cast<LLVM::LLVMType>();
+  auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
+  align = dataLayout.getPrefTypeAlignment(llvmType.getUnderlyingType());
+  return success();
+}
+
 template <>
 void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
                                        LLVMTypeConverter &typeConverter,
@@ -764,10 +777,13 @@ void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
   Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
   fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
 
-  auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
-  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
-      op, vecTy, dataPtr, mask, ValueRange{fill},
-      rewriter.getI32IntegerAttr(1));
+  LLVM::LLVMType vecTy;
+  unsigned align;
+  if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+                                        vecTy, align)))
+    rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
+        op, vecTy, dataPtr, mask, ValueRange{fill},
+        rewriter.getI32IntegerAttr(align));
 }
 
 template <>
@@ -777,8 +793,14 @@ void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
                                         ArrayRef<Value> operands, Value dataPtr,
                                         Value mask) {
   auto adaptor = TransferWriteOpOperandAdaptor(operands);
-  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
-      op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
+
+  auto xferOp = cast<TransferWriteOp>(op);
+  LLVM::LLVMType vecTy;
+  unsigned align;
+  if (succeeded(getLLVMTypeAndAlignment(typeConverter, xferOp.getVectorType(),
+                                        vecTy, align)))
+    rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
+        op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align));
 }
 
 static TransferReadOpOperandAdaptor

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 4d42b8d9b570..bd42b73f4496 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -818,7 +818,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //       CHECK: %[[PASS_THROUGH:.*]] =  llvm.mlir.constant(dense<7.000000e+00> :
 //  CHECK-SAME:  vector<17xf32>) : !llvm<"<17 x float>">
 //       CHECK: %[[loaded:.*]] = llvm.intr.masked.load %[[vecPtr]], %[[mask]],
-//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 1 : i32} :
+//  CHECK-SAME: %[[PASS_THROUGH]] {alignment = 128 : i32} :
 //  CHECK-SAME: (!llvm<"<17 x float>*">, !llvm<"<17 x i1>">, !llvm<"<17 x float>">) -> !llvm<"<17 x float>">
 
 //
@@ -850,7 +850,7 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) -> vector<17xf32> {
 //
 // 5. Rewrite as a masked write.
 //       CHECK: llvm.intr.masked.store %[[loaded]], %[[vecPtr_b]], %[[mask_b]]
-//  CHECK-SAME: {alignment = 1 : i32} :
+//  CHECK-SAME: {alignment = 128 : i32} :
 //  CHECK-SAME: !llvm<"<17 x float>">, !llvm<"<17 x i1>"> into !llvm<"<17 x float>*">
 
 func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<17xf32> {


        


More information about the Mlir-commits mailing list