[Mlir-commits] [mlir] affbc0c - [mlir] Add alignment attribute to LLVM memory ops and use in vector.transfer

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 13 14:38:20 PDT 2020


Author: Nicolas Vasilache
Date: 2020-07-13T17:35:20-04:00
New Revision: affbc0cd1cc87826c2636f8903d85c911aef75ff

URL: https://github.com/llvm/llvm-project/commit/affbc0cd1cc87826c2636f8903d85c911aef75ff
DIFF: https://github.com/llvm/llvm-project/commit/affbc0cd1cc87826c2636f8903d85c911aef75ff.diff

LOG: [mlir] Add alignment attribute to LLVM memory ops and use in vector.transfer

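Summary: The native alignment of the vector type cannot, in general, be assumed when lowering a vector.transfer to the underlying load/store operation. This revision adds an optional `alignment` attribute to the LLVM dialect alloca/load/store ops and uses it to make the unmasked load/store alignment match that of the masked path.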

Differential Revision: https://reviews.llvm.org/D83684

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
    mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 663a820905ce..ce0b3de82d2c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -215,19 +215,36 @@ def LLVM_FDivOp : LLVM_ArithmeticOp<"fdiv", "CreateFDiv">;
 def LLVM_FRemOp : LLVM_ArithmeticOp<"frem", "CreateFRem">;
 def LLVM_FNegOp : LLVM_UnaryArithmeticOp<"fneg", "CreateFNeg">;
 
+// Shared snippets for LLVM ops carrying an optional `alignment` attribute: the
+// verifier rejects negative values and the builder code applies non-zero ones.
+class MemoryOpWithAlignmentBase {
+  code alignmentVerifierCode = [{
+    if (alignment().hasValue()) {
+      auto align = alignment().getValue().getSExtValue();
+      if (align < 0) // NOTE(review): 0 and non-powers-of-two pass; llvm::Align needs a power of two - confirm.
+        return emitOpError("expected positive alignment");
+    }
+    return success();
+  }];
+  code setAlignmentCode = [{
+    if ($alignment.hasValue()) {
+      auto align = $alignment.getValue().getZExtValue();
+      if (align != 0) // 0 means unspecified: keep the alignment inst was created with.
+        inst->setAlignment(llvm::Align(align));
+    }
+  }];
+}
+
 // Memory-related operations.
 def LLVM_AllocaOp :
+    MemoryOpWithAlignmentBase,
     LLVM_OneResultOp<"alloca">,
     Arguments<(ins LLVM_Type:$arraySize, OptionalAttr<I64Attr>:$alignment)> {
   string llvmBuilder = [{
-    auto *alloca = builder.CreateAlloca(
+    auto *inst = builder.CreateAlloca(
       $_resultType->getPointerElementType(), $arraySize);
-    if ($alignment.hasValue()) {
-      auto align = $alignment.getValue().getZExtValue();
-      if (align != 0)
-        alloca->setAlignment(llvm::Align(align));
-    }
-    $res = alloca;
+    }] # setAlignmentCode # [{
+    $res = inst;
   }];
   let builders = [OpBuilder<
     "OpBuilder &b, OperationState &result, Type resultType, Value arraySize, "
@@ -239,14 +256,7 @@ def LLVM_AllocaOp :
   }]>];
   let parser = [{ return parseAllocaOp(parser, result); }];
   let printer = [{ printAllocaOp(p, *this); }];
-  let verifier = [{
-    if (alignment().hasValue()) {
-      auto align = alignment().getValue().getSExtValue();
-      if (align < 0)
-        return emitOpError("expected positive alignment");
-    }
-    return success();
-  }];
+  let verifier = alignmentVerifierCode;
 }
 def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
                  Arguments<(ins LLVM_Type:$base, Variadic<LLVM_Type>:$indices)>,
@@ -255,22 +265,56 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
     $base `[` $indices `]` attr-dict `:` functional-type(operands, results)
   }];
 }
-def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
-                  LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
+def LLVM_LoadOp :
+    MemoryOpWithAlignmentBase,
+    LLVM_OneResultOp<"load">,
+    Arguments<(ins LLVM_Type:$addr, OptionalAttr<I64Attr>:$alignment)> {
+  string llvmBuilder = [{
+    auto *inst = builder.CreateLoad($addr);
+  }] # setAlignmentCode # [{
+    $res = inst;
+  }];
   let builders = [OpBuilder<
-    "OpBuilder &b, OperationState &result, Value addr",
+    "OpBuilder &b, OperationState &result, Value addr, unsigned alignment = 0",
     [{
       auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
-      build(b, result, type, addr);
+      build(b, result, type, addr, alignment); // Delegate to the explicit-type builder below.
+    }]>,
+    OpBuilder<
+    "OpBuilder &b, OperationState &result, Type t, Value addr, "
+    "unsigned alignment = 0",
+    [{
+      if (alignment == 0) // 0 means unspecified: leave the optional attribute unset.
+        return build(b, result, t, addr, IntegerAttr());
+      build(b, result, t, addr, b.getI64IntegerAttr(alignment));
     }]>];
   let parser = [{ return parseLoadOp(parser, result); }];
   let printer = [{ printLoadOp(p, *this); }];
+  let verifier = alignmentVerifierCode;
 }
-def LLVM_StoreOp : LLVM_ZeroResultOp<"store">,
-                   Arguments<(ins LLVM_Type:$value, LLVM_Type:$addr)>,
-                   LLVM_Builder<"builder.CreateStore($value, $addr);"> {
+def LLVM_StoreOp :
+    MemoryOpWithAlignmentBase,
+    LLVM_ZeroResultOp<"store">,
+    Arguments<(ins LLVM_Type:$value,
+                   LLVM_Type:$addr,
+                   OptionalAttr<I64Attr>:$alignment)> {
+  string llvmBuilder = [{
+    auto *inst = builder.CreateStore($value, $addr);
+  }] # setAlignmentCode;
+  let builders = [
+    OpBuilder<
+    "OpBuilder &b, OperationState &result, Value value, Value addr, "
+    "unsigned alignment = 0",
+    [{
+      if (alignment == 0) // 0 means unspecified: leave the optional attribute unset.
+        return build(b, result, ArrayRef<Type>{}, value, addr, IntegerAttr());
+      build(b, result, ArrayRef<Type>{}, value, addr, 
+            b.getI64IntegerAttr(alignment));
+    }]
+  >];
   let parser = [{ return parseStoreOp(parser, result); }];
   let printer = [{ printStoreOp(p, *this); }];
+  let verifier = alignmentVerifierCode;
 }
 
 // Casts.

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
index f8934f06c0fd..e6fa0df1ed7e 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-read.mlir
@@ -12,6 +12,15 @@ func @transfer_read_1d(%A : memref<?xf32>, %base: index) {
   return
 }
 
+func @transfer_read_unmasked_4(%A : memref<?xf32>, %base: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base], %fm42
+      {permutation_map = affine_map<(d0) -> (d0)>, masked = [false]} : // unmasked: lowered to a plain llvm.load, so all 4 lanes must be in bounds.
+    memref<?xf32>, vector<4xf32>
+  vector.print %f: vector<4xf32>
+  return
+}
+
 func @transfer_write_1d(%A : memref<?xf32>, %base: index) {
   %f0 = constant 0.0 : f32
   %vf0 = splat %f0 : vector<4xf32>
@@ -44,8 +53,12 @@ func @entry() {
   // Read shifted by 0 and pad with -42:
   //   ( 0, 1, 2, 0, 0, -42, ..., -42)
   call @transfer_read_1d(%A, %c0) : (memref<?xf32>, index) -> ()
+  // Read unmasked 4 @ 1, guaranteed to not overflow.
+  // Exercises proper alignment.
+  call @transfer_read_unmasked_4(%A, %c1) : (memref<?xf32>, index) -> ()
   return
 }
 
 // CHECK: ( 2, 3, 4, -42, -42, -42, -42, -42, -42, -42, -42, -42, -42 )
 // CHECK: ( 0, 1, 2, 0, 0, -42, -42, -42, -42, -42, -42, -42, -42 )
+// CHECK: ( 1, 2, 0, 0 )

diff  --git a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
index 57163700fc99..c61a1629dcfb 100644
--- a/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
+++ b/mlir/integration_test/Dialect/Vector/CPU/test-transfer-write.mlir
@@ -3,11 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
-func @transfer_write16_1d(%A : memref<?xf32>, %base: index) {
+func @transfer_write16_unmasked_1d(%A : memref<?xf32>, %base: index) {
   %f = constant 16.0 : f32
   %v = splat %f : vector<16xf32>
   vector.transfer_write %v, %A[%base]
-    {permutation_map = affine_map<(d0) -> (d0)>}
+    {permutation_map = affine_map<(d0) -> (d0)>, masked = [false]}
     : vector<16xf32>, memref<?xf32>
   return
 }
@@ -53,14 +53,14 @@ func @entry() {
   %0 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %0 : vector<32xf32>
 
-  // Overwrite with 16 values of 16 at base 4.
-  %c4 = constant 4: index
-  call @transfer_write16_1d(%A, %c4) : (memref<?xf32>, index) -> ()
+  // Overwrite with 16 values of 16 at base 3.
+  // Statically guaranteed to be unmasked. Exercises proper alignment.
+  %c3 = constant 3: index
+  call @transfer_write16_unmasked_1d(%A, %c3) : (memref<?xf32>, index) -> ()
   %1 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %1 : vector<32xf32>
 
   // Overwrite with 13 values of 13 at base 3.
-  %c3 = constant 3: index
   call @transfer_write13_1d(%A, %c3) : (memref<?xf32>, index) -> ()
   %2 = call @transfer_read_1d(%A) : (memref<?xf32>) -> (vector<32xf32>)
   vector.print %2 : vector<32xf32>
@@ -93,8 +93,8 @@ func @entry() {
 }
 
 // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-// CHECK: ( 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
-// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+// CHECK: ( 0, 0, 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
+// CHECK: ( 0, 0, 0, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )
 // CHECK: ( 0, 0, 0, 17, 17, 17, 17, 17, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 0 )

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 2be2bd9bb7d0..a59f02681c54 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -143,7 +143,10 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  LLVMTypeConverter &typeConverter, Location loc,
                                  TransferReadOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
-  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr);
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
+  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(xferOp, dataPtr, align);
   return success();
 }
 
@@ -176,8 +179,12 @@ replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  LLVMTypeConverter &typeConverter, Location loc,
                                  TransferWriteOp xferOp,
                                  ArrayRef<Value> operands, Value dataPtr) {
+  unsigned align;
+  if (failed(getVectorTransferAlignment(typeConverter, xferOp, align)))
+    return failure();
   auto adaptor = TransferWriteOpAdaptor(operands);
-  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr);
+  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(xferOp, adaptor.vector(), dataPtr,
+                                             align);
   return success();
 }
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 829edf5f66f1..874cb5cca141 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -935,7 +935,7 @@ func @transfer_read_1d_not_masked(%A : memref<?xf32>, %base: index) -> vector<17
 //  CHECK-SAME: !llvm<"float*"> to !llvm<"<17 x float>*">
 //
 // 2. Rewrite as a load.
-//       CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] : !llvm<"<17 x float>*">
+//       CHECK: %[[loaded:.*]] = llvm.load %[[vecPtr]] {alignment = 4 : i64} : !llvm<"<17 x float>*">
 
 func @genbool_1d() -> vector<8xi1> {
   %0 = vector.constant_mask [4] : vector<8xi1>


        


More information about the Mlir-commits mailing list