[Mlir-commits] [mlir] adabce4 - Correctly model undefined behavior in {tensor|memref}.dim
Sanjoy Das
llvmlistbot at llvm.org
Wed Oct 12 17:30:32 PDT 2022
Author: Sanjoy Das
Date: 2022-10-12T17:30:13-07:00
New Revision: adabce41185910227ca276a1cfd22e76443dd238
URL: https://github.com/llvm/llvm-project/commit/adabce41185910227ca276a1cfd22e76443dd238
DIFF: https://github.com/llvm/llvm-project/commit/adabce41185910227ca276a1cfd22e76443dd238.diff
LOG: Correctly model undefined behavior in {tensor|memref}.dim
These operations have undefined behavior if the index is not less than the rank of the source tensor / memref, so they cannot be freely speculated like they were before this patch. After this patch we speculate them only if we can prove that they don't have UB.
Depends on D135505.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D135748
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Transforms/loop-invariant-code-motion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c94a531019209..54394dadafcad 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -544,7 +544,7 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
def MemRef_DimOp : MemRef_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
MemRefsNormalizable,
- Pure,
+ ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
let summary = "dimension index operation";
let description = [{
@@ -593,6 +593,9 @@ def MemRef_DimOp : MemRef_Op<"dim", [
/// Interface method of ShapedDimOpInterface: Return the dimension.
OpFoldResult getDimension() { return getIndex(); }
+
+ /// Interface method for ConditionallySpeculatable.
+ Speculation::Speculatability getSpeculatability();
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index bdc24fa0675e1..00887566812ef 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -87,7 +87,7 @@ def Tensor_CastOp : Tensor_Op<"cast", [
def Tensor_DimOp : Tensor_Op<"dim", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- Pure,
+ ConditionallySpeculatable, NoMemoryEffect,
ShapedDimOpInterface]> {
let summary = "dimension index operation";
let description = [{
@@ -135,6 +135,9 @@ def Tensor_DimOp : Tensor_Op<"dim", [
/// Interface method of ShapedDimOpInterface: Return the dimension.
OpFoldResult getDimension() { return getIndex(); }
+
+ /// Interface method for ConditionallySpeculatable.
+ Speculation::Speculatability getSpeculatability();
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 292eb4618aac6..fbc1eadffcba3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -819,6 +819,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
+/// Interface method for ConditionallySpeculatable.
+///
+/// `memref.dim` has undefined behavior when the index does not name a valid
+/// dimension of the source, so the op is reported as speculatable only when
+/// the index is a known constant, the source memref is ranked, and the index
+/// is provably inside [0, rank).
+Speculation::Speculatability DimOp::getSpeculatability() {
+  // A non-constant index cannot be proven to be in bounds.
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  // Without a ranked source type there is no rank to compare against.
+  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // The verifier is expected to reject constant indices that are >= rank, but
+  // an assert would vanish in NDEBUG builds and would not cover negative
+  // indices. Stay conservative instead of trusting the invariant.
+  if (*constantIndex < 0 || *constantIndex >= rankedSourceType.getRank())
+    return Speculation::NotSpeculatable;
+
+  return Speculation::Speculatable;
+}
+
LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = getConstantIndex();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 448e97cb97f21..0ee79a6ea268e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -328,6 +328,20 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
+/// Interface method for ConditionallySpeculatable.
+///
+/// `tensor.dim` has undefined behavior when the index does not name a valid
+/// dimension of the source, so the op is reported as speculatable only when
+/// the index is a known constant, the source tensor is ranked, and the index
+/// is provably inside [0, rank).
+Speculation::Speculatability DimOp::getSpeculatability() {
+  // A non-constant index cannot be proven to be in bounds.
+  auto constantIndex = getConstantIndex();
+  if (!constantIndex)
+    return Speculation::NotSpeculatable;
+
+  // Without a ranked source type there is no rank to compare against.
+  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
+  if (!rankedSourceType)
+    return Speculation::NotSpeculatable;
+
+  // The verifier is expected to reject constant indices that are >= rank, but
+  // an assert would vanish in NDEBUG builds and would not cover negative
+  // indices. Stay conservative instead of trusting the invariant.
+  if (*constantIndex < 0 || *constantIndex >= rankedSourceType.getRank())
+    return Speculation::NotSpeculatable;
+
+  return Speculation::Speculatable;
+}
+
LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = getConstantIndex();
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 0b74c81b6de13..b8d3450862a39 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -503,3 +503,107 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}
+
+// -----
+
+// Unranked source tensor with a dynamic index: the index cannot be proven to
+// be in bounds, so tensor.dim is expected to stay inside the loop body.
+func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
+// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_unknown_dim
+ %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ // CHECK: scf.for
+ // CHECK-NEXT: tensor.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = tensor.dim %t, %dim_idx : tensor<*xf32>
+ }
+
+ return
+}
+
+// Ranked (rank-4) source tensor but a dynamic index: the index value is not
+// known, so tensor.dim is expected to stay inside the loop body.
+func.func @speculate_tensor_dim_known_rank_unknown_dim(
+// CHECK-LABEL: @speculate_tensor_dim_known_rank_unknown_dim
+ %t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ // CHECK: scf.for
+ // CHECK-NEXT: tensor.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = tensor.dim %t, %dim_idx : tensor<?x?x?x?xf32>
+ }
+
+ return
+}
+
+// Constant index 0 but an unranked source tensor: the rank is not known, so
+// tensor.dim is expected to stay inside the loop body.
+// Note: the %dim_idx argument is not used by this function.
+func.func @speculate_tensor_dim_unknown_rank_known_dim(
+// CHECK-LABEL: @speculate_tensor_dim_unknown_rank_known_dim
+ %t: tensor<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ %c0 = arith.constant 0 : index
+ // CHECK: scf.for
+ // CHECK-NEXT: tensor.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = tensor.dim %t, %c0 : tensor<*xf32>
+ }
+
+ return
+}
+
+// Constant index 1 into a rank-4 source tensor: the index is provably in
+// bounds, so tensor.dim is expected to be hoisted in front of the loop.
+// Note: the %dim_idx argument is not used by this function.
+func.func @speculate_tensor_dim_known_rank_known_dim_inbounds(
+// CHECK-LABEL: @speculate_tensor_dim_known_rank_known_dim_inbounds
+ %t: tensor<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ %c1 = arith.constant 1 : index
+ // CHECK: tensor.dim
+ // CHECK-NEXT: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %val = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+ }
+
+ return
+}
+
+// -----
+
+// Unranked source memref with a dynamic index: the index cannot be proven to
+// be in bounds, so memref.dim is expected to stay inside the loop body.
+func.func @speculate_memref_dim_unknown_rank_unknown_dim(
+// CHECK-LABEL: @speculate_memref_dim_unknown_rank_unknown_dim
+ %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ // CHECK: scf.for
+ // CHECK-NEXT: memref.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = memref.dim %t, %dim_idx : memref<*xf32>
+ }
+
+ return
+}
+
+// Ranked (rank-4) source memref but a dynamic index: the index value is not
+// known, so memref.dim is expected to stay inside the loop body.
+func.func @speculate_memref_dim_known_rank_unknown_dim(
+// CHECK-LABEL: @speculate_memref_dim_known_rank_unknown_dim
+ %t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ // CHECK: scf.for
+ // CHECK-NEXT: memref.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = memref.dim %t, %dim_idx : memref<?x?x?x?xf32>
+ }
+
+ return
+}
+
+// Constant index 0 but an unranked source memref: the rank is not known, so
+// memref.dim is expected to stay inside the loop body.
+// Note: the %dim_idx argument is not used by this function.
+func.func @speculate_memref_dim_unknown_rank_known_dim(
+// CHECK-LABEL: @speculate_memref_dim_unknown_rank_known_dim
+ %t: memref<*xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ %c0 = arith.constant 0 : index
+ // CHECK: scf.for
+ // CHECK-NEXT: memref.dim
+ scf.for %i = %lb to %ub step %step {
+ %val = memref.dim %t, %c0 : memref<*xf32>
+ }
+
+ return
+}
+
+// Constant index 1 into a rank-4 source memref: the index is provably in
+// bounds, so memref.dim is expected to be hoisted in front of the loop.
+// Note: the %dim_idx argument is not used by this function.
+func.func @speculate_memref_dim_known_rank_known_dim_inbounds(
+// CHECK-LABEL: @speculate_memref_dim_known_rank_known_dim_inbounds
+ %t: memref<?x?x?x?xf32>, %dim_idx: index, %lb: index, %ub: index, %step: index) {
+ %c1 = arith.constant 1 : index
+ // CHECK: memref.dim
+ // CHECK-NEXT: scf.for
+ scf.for %i = %lb to %ub step %step {
+ %val = memref.dim %t, %c1 : memref<?x?x?x?xf32>
+ }
+
+ return
+}
More information about the Mlir-commits
mailing list