[Mlir-commits] [mlir] f581e65 - [MLIR] Add std.assume_alignment op.

Tim Shen llvmlistbot at llvm.org
Tue Feb 18 17:55:28 PST 2020


Author: Tim Shen
Date: 2020-02-18T17:55:07-08:00
New Revision: f581e655ec3f34dcd704ffc9586bfb615a459942

URL: https://github.com/llvm/llvm-project/commit/f581e655ec3f34dcd704ffc9586bfb615a459942
DIFF: https://github.com/llvm/llvm-project/commit/f581e655ec3f34dcd704ffc9586bfb615a459942.diff

LOG: [MLIR] Add std.assume_alignment op.

Reviewers: ftynse, nicolasvasilache, andydavis1

Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74378

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/StandardOps/Ops.td
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/Ops.cpp
    mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
    mlir/test/IR/core-ops.mlir
    mlir/test/IR/invalid-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ebba7a4a7aeb..d73fd1187431 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -870,4 +870,14 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">,
   let verifier = "return ::verify(*this);";
 }
 
+def LLVM_AssumeOp : LLVM_Op<"intr.assume", []>,
+                    Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn =
+        llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::assume, {});
+    builder.CreateCall(fn, {$cond});
+  }];
+}
+
 #endif // LLVMIR_OPS
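
The new llvm.intr.assume op above is a thin wrapper around LLVM's assume
intrinsic: its llvmBuilder looks up the intrinsic's declaration in the current
module and emits a call to it with the i1 condition. A minimal sketch of the op
as it appears in the LLVM dialect (function and value names are illustrative,
not part of this commit):

    llvm.func @assume_example(%cond: !llvm.i1) {
      // Tells the optimizer it may assume %cond holds; mlir-translate turns
      // this into a call to the llvm.assume intrinsic.
      "llvm.intr.assume"(%cond) : (!llvm.i1) -> ()
      llvm.return
    }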

diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 0f4a1c7a0ea0..73a88681da3c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1639,4 +1639,21 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
   }];
 }
 
+def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
+  let summary =
+      "assertion that gives alignment information to the input memref";
+  let description = [{
+    The assume_alignment operation takes a memref and an integer alignment
+    value, and internally annotates the buffer with the given alignment. If
+    the buffer isn't aligned to the given alignment, the behavior is undefined.
+
+    This operation doesn't affect the semantics of a correct program. It's for
+    optimization only, and the optimization is best-effort.
+  }];
+  let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
+  let results = (outs);
+
+  let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+}
+
 #endif // STANDARD_OPS
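
Together with the verifier added in Ops.cpp below, the new op accepts any
memref plus a positive, power-of-two 32-bit alignment attribute. A small usage
sketch, assuming a hypothetical buffer that the caller guarantees is 16-byte
aligned (the same alignment the tests in this commit use):

    func @aligned_use(%buf: memref<4x4xf16>) {
      // Assert that the underlying buffer of %buf is 16-byte aligned; the
      // behavior is undefined if it is not. A value such as 12 (not a power
      // of two) or 0 (not positive) is rejected by the verifier.
      assume_alignment %buf, 16 : memref<4x4xf16>
      return
    }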

diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 88ca986874e4..97a6d618dd82 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2501,6 +2501,45 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
   }
 };
 
+struct AssumeAlignmentOpLowering
+    : public LLVMLegalizationPattern<AssumeAlignmentOp> {
+  using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<AssumeAlignmentOp> transformed(operands);
+    Value memref = transformed.memref();
+    unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();
+
+    MemRefDescriptor memRefDescriptor(memref);
+    Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc());
+
+    // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that
+    // the asserted memref.alignedPtr isn't used anywhere else, as the real
+    // users like load/store/views always re-extract memref.alignedPtr as they
+    // get lowered.
+    //
+    // This relies on LLVM's CSE optimization (potentially after SROA), since
+    // after CSE all memref.alignedPtr instances get de-duplicated into the same
+    // pointer SSA value.
+    Value zero =
+        createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0);
+    Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(),
+                                         alignment - 1);
+    Value ptrValue =
+        rewriter.create<LLVM::PtrToIntOp>(op->getLoc(), getIndexType(), ptr);
+    rewriter.create<LLVM::AssumeOp>(
+        op->getLoc(),
+        rewriter.create<LLVM::ICmpOp>(
+            op->getLoc(), LLVM::ICmpPredicate::eq,
+            rewriter.create<LLVM::AndOp>(op->getLoc(), ptrValue, mask), zero));
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -2612,6 +2651,7 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
     bool useAlloca) {
   // clang-format off
   patterns.insert<
+      AssumeAlignmentOpLowering,
       DimOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,

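The lowering re-derives the aligned pointer from the memref descriptor and
asserts that its low bits are zero, using ptrtoint, and, icmp "eq", and the
llvm.intr.assume op added above. A readable rendering of what the FileCheck
lines in the new convert-to-llvmir.mlir test verify for a 16-byte alignment
(SSA value names are illustrative):

    %ptr  = llvm.extractvalue %desc[1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
    %zero = llvm.mlir.constant(0 : index) : !llvm.i64
    %mask = llvm.mlir.constant(15 : index) : !llvm.i64
    %addr = llvm.ptrtoint %ptr : !llvm<"half*"> to !llvm.i64
    %rem  = llvm.and %addr, %mask : !llvm.i64
    %cond = llvm.icmp "eq" %rem, %zero : !llvm.i64
    "llvm.intr.assume"(%cond) : (!llvm.i1) -> ()
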
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 81f5d4a153f6..9a0bcb9d58a0 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2764,6 +2764,17 @@ SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AssumeAlignmentOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AssumeAlignmentOp op) {
+  unsigned alignment = op.alignment().getZExtValue();
+  if (!llvm::isPowerOf2_32(alignment))
+    return op.emitOpError("alignment must be power of 2");
+  return success();
+}
+
 namespace {
 
 /// Pattern to rewrite a subview op with constant size arguments.

diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 4c5a4a078a11..8839514937e0 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -855,3 +855,18 @@ module {
 // CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
 // CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
 }
+
+// -----
+
+// CHECK-LABEL: func @assume_alignment
+func @assume_alignment(%0 : memref<4x4xf16>) {
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK:.*]] : !llvm.i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
+  assume_alignment %0, 16 : memref<4x4xf16>
+  return
+}

diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 2318ef5e7eec..382b1602df0d 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -740,3 +740,11 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
   tensor_store %1, %0 : memref<4x4xi32>
   return
 }
+
+// CHECK-LABEL: func @assume_alignment
+// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // CHECK: assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
+  assume_alignment %0, 16 : memref<4x4xf16>
+  return
+}

diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index d96f646e60a7..9f48d9d6bc70 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1036,3 +1036,21 @@ func @invalid_memref_cast() {
   %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
   return
 }
+
+// -----
+
+// alignment is not power of 2.
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{alignment must be power of 2}}
+  std.assume_alignment %0, 12 : memref<4x4xf16>
+  return
+}
+
+// -----
+
+// 0 alignment value.
+func @assume_alignment(%0: memref<4x4xf16>) {
+  // expected-error@+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: positive 32-bit integer attribute}}
+  std.assume_alignment %0, 0 : memref<4x4xf16>
+  return
+}


        

