[Mlir-commits] [mlir] f581e65 - [MLIR] Add std.assume_alignment op.
Tim Shen
llvmlistbot at llvm.org
Tue Feb 18 17:55:28 PST 2020
Author: Tim Shen
Date: 2020-02-18T17:55:07-08:00
New Revision: f581e655ec3f34dcd704ffc9586bfb615a459942
URL: https://github.com/llvm/llvm-project/commit/f581e655ec3f34dcd704ffc9586bfb615a459942
DIFF: https://github.com/llvm/llvm-project/commit/f581e655ec3f34dcd704ffc9586bfb615a459942.diff
LOG: [MLIR] Add std.assume_alignment op.
Reviewers: ftynse, nicolasvasilache, andydavis1
Subscribers: bixia, sanjoy.google, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74378
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ebba7a4a7aeb..d73fd1187431 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -870,4 +870,14 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">,
let verifier = "return ::verify(*this);";
}
+// Models a call to the LLVM `assume` intrinsic. The single operand `$cond` is
+// forwarded as the only argument of the generated intrinsic call.
+def LLVM_AssumeOp
+    : LLVM_Op<"intr.assume", []>, Arguments<(ins LLVM_Type:$cond)> {
+  let llvmBuilder = [{
+    // Look up the intrinsic declaration in the module that owns the current
+    // insertion block, then emit the call on the translated condition.
+    llvm::Function *assumeFn = llvm::Intrinsic::getDeclaration(
+        builder.GetInsertBlock()->getModule(), llvm::Intrinsic::assume, {});
+    builder.CreateCall(assumeFn, {$cond});
+  }];
+}
+
#endif // LLVMIR_OPS
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 0f4a1c7a0ea0..73a88681da3c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1639,4 +1639,21 @@ def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]>
}];
}
+def AssumeAlignmentOp : Std_Op<"assume_alignment"> {
+  let summary =
+      "assertion that gives alignment information to the input memref";
+  let description = [{
+    The assume alignment operation takes a memref and an integer alignment
+    value, and internally annotates the buffer with the given alignment. The
+    alignment must be a positive power of 2. If the buffer isn't aligned to the
+    given alignment, the behavior is undefined.
+
+    This operation doesn't affect the semantics of a correct program. It's for
+    optimization only, and the optimization is best-effort.
+
+    Example:
+
+    ```mlir
+    assume_alignment %0, 16 : memref<4x4xf16>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$memref, PositiveI32Attr:$alignment);
+  let results = (outs);
+
+  let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+}
+
#endif // STANDARD_OPS
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 88ca986874e4..97a6d618dd82 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2501,6 +2501,45 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
}
};
+/// Lowers `std.assume_alignment` to an `llvm.intr.assume` whose condition
+/// states that the low bits of the memref's aligned pointer are zero.
+struct AssumeAlignmentOpLowering
+    : public LLVMLegalizationPattern<AssumeAlignmentOp> {
+  using LLVMLegalizationPattern<AssumeAlignmentOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto indexType = getIndexType();
+    unsigned alignment = cast<AssumeAlignmentOp>(op).alignment().getZExtValue();
+
+    // Pull the aligned pointer out of the already-converted memref operand.
+    OperandAdaptor<AssumeAlignmentOp> adaptor(operands);
+    Value convertedMemref = adaptor.memref();
+    MemRefDescriptor descriptor(convertedMemref);
+    Value alignedPtr =
+        descriptor.alignedPtr(rewriter, convertedMemref.getLoc());
+
+    // Build llvm.assume((ptrtoint(alignedPtr) & (alignment - 1)) == 0).
+    //
+    // The pointer asserted on here has no other users: loads, stores and views
+    // each re-extract the aligned pointer from the descriptor when they are
+    // lowered. The assumption therefore relies on LLVM's CSE optimization
+    // (potentially after SROA) de-duplicating all of those extractions into
+    // the same pointer SSA value.
+    Value zero = createIndexAttrConstant(rewriter, loc, indexType, 0);
+    Value mask =
+        createIndexAttrConstant(rewriter, loc, indexType, alignment - 1);
+    Value ptrAsInt =
+        rewriter.create<LLVM::PtrToIntOp>(loc, indexType, alignedPtr);
+    Value lowBits = rewriter.create<LLVM::AndOp>(loc, ptrAsInt, mask);
+    Value isAligned = rewriter.create<LLVM::ICmpOp>(
+        loc, LLVM::ICmpPredicate::eq, lowBits, zero);
+    rewriter.create<LLVM::AssumeOp>(loc, isAligned);
+
+    // The std op has no results, so it can simply be dropped.
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
@@ -2612,6 +2651,7 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
bool useAlloca) {
// clang-format off
patterns.insert<
+ AssumeAlignmentOpLowering,
DimOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 81f5d4a153f6..9a0bcb9d58a0 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2764,6 +2764,17 @@ SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
return success();
}
+//===----------------------------------------------------------------------===//
+// AssumeAlignmentOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the `alignment` attribute of `op` is a power of 2; emits an
+/// op error otherwise.
+static LogicalResult verify(AssumeAlignmentOp op) {
+  const unsigned requestedAlignment = op.alignment().getZExtValue();
+  if (llvm::isPowerOf2_32(requestedAlignment))
+    return success();
+  return op.emitOpError("alignment must be power of 2");
+}
+
namespace {
/// Pattern to rewrite a subview op with constant size arguments.
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 4c5a4a078a11..8839514937e0 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -855,3 +855,18 @@ module {
// CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
// CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
}
+
+// -----
+
+// The lowering extracts the aligned pointer (field 1 of the memref
+// descriptor) and asserts that its low bits, selected by the mask
+// `alignment - 1` (15 for an alignment of 16), are zero.
+// CHECK-LABEL: func @assume_alignment
+func @assume_alignment(%0 : memref<4x4xf16>) {
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK-NEXT: %[[MASK:.*]] = llvm.mlir.constant(15 : index) : !llvm.i64
+  // CHECK-NEXT: %[[INT:.*]] = llvm.ptrtoint %[[PTR]] : !llvm<"half*"> to !llvm.i64
+  // CHECK-NEXT: %[[MASKED_PTR:.*]] = llvm.and %[[INT]], %[[MASK]] : !llvm.i64
+  // CHECK-NEXT: %[[CONDITION:.*]] = llvm.icmp "eq" %[[MASKED_PTR]], %[[ZERO]] : !llvm.i64
+  // CHECK-NEXT: "llvm.intr.assume"(%[[CONDITION]]) : (!llvm.i1) -> ()
+  assume_alignment %0, 16 : memref<4x4xf16>
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 2318ef5e7eec..382b1602df0d 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -740,3 +740,11 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
tensor_store %1, %0 : memref<4x4xi32>
return
}
+
+// The op is written in its custom assembly form; the FileCheck directive in
+// the body expects it to be printed back in that same form.
+// CHECK-LABEL: func @assume_alignment
+// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
+func @assume_alignment(%0: memref<4x4xf16>) {
+ // CHECK: assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
+ assume_alignment %0, 16 : memref<4x4xf16>
+ return
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index d96f646e60a7..9f48d9d6bc70 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1036,3 +1036,21 @@ func @invalid_memref_cast() {
%2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
return
}
+
+// -----
+
+// alignment is not power of 2.
+func @assume_alignment(%0: memref<4x4xf16>) {
+ // expected-error @+1 {{alignment must be power of 2}}
+ std.assume_alignment %0, 12 : memref<4x4xf16>
+ return
+}
+
+// -----
+
+// 0 alignment value.
+func @assume_alignment(%0: memref<4x4xf16>) {
+ // expected-error @+1 {{'std.assume_alignment' op attribute 'alignment' failed to satisfy constraint: positive 32-bit integer attribute}}
+ std.assume_alignment %0, 0 : memref<4x4xf16>
+ return
+}
More information about the Mlir-commits
mailing list