[Mlir-commits] [mlir] a6a583d - [MLIR] Move AtomicRMW into MemRef dialect and enum into Arith
William S. Moses
llvmlistbot at llvm.org
Thu Dec 30 11:31:38 PST 2021
Author: William S. Moses
Date: 2021-12-30T14:31:33-05:00
New Revision: a6a583dae40485cacfac56811e6d9131bac6ca74
URL: https://github.com/llvm/llvm-project/commit/a6a583dae40485cacfac56811e6d9131bac6ca74
DIFF: https://github.com/llvm/llvm-project/commit/a6a583dae40485cacfac56811e6d9131bac6ca74.diff
LOG: [MLIR] Move AtomicRMW into MemRef dialect and enum into Arith
Per the discussion in https://reviews.llvm.org/D116345, it makes sense
to move AtomicRMWOp out of the standard dialect. This was accentuated by the
need to add a folder that looks through a memref::cast. The only dialect
that would permit this is the memref dialect (keeping it in the standard dialect
or moving it to the arithmetic dialect would require those dialects to
depend on the memref dialect, which breaks linking).
As the AtomicRMWKind enum is used throughout, this has been moved to Arith.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D116392
Added:
Modified:
mlir/include/mlir/Analysis/AffineAnalysis.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/invalid.mlir
mlir/test/Dialect/MemRef/ops.mlir
mlir/test/Dialect/Standard/expand-ops.mlir
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
Removed:
mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index 120a5be4596ac..fa793a9e17f84 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -15,6 +15,7 @@
#ifndef MLIR_ANALYSIS_AFFINE_ANALYSIS_H
#define MLIR_ANALYSIS_AFFINE_ANALYSIS_H
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/Optional.h"
@@ -32,7 +33,7 @@ class Operation;
/// A description of a (parallelizable) reduction in an affine loop.
struct LoopReduction {
/// Reduction kind.
- AtomicRMWKind kind;
+ arith::AtomicRMWKind kind;
/// Position of the iteration argument that acts as accumulator.
unsigned iterArgPosition;
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c04115476d35e..53b58fa23d342 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -13,7 +13,7 @@
#ifndef AFFINE_OPS
#define AFFINE_OPS
-include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
+include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
@@ -691,9 +691,9 @@ def AffineParallelOp : Affine_Op<"parallel",
let builders = [
OpBuilder<(ins "TypeRange":$resultTypes,
- "ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
+ "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<int64_t>":$ranges)>,
OpBuilder<(ins "TypeRange":$resultTypes,
- "ArrayRef<AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
+ "ArrayRef<arith::AtomicRMWKind>":$reductions, "ArrayRef<AffineMap>":$lbMaps,
"ValueRange":$lbArgs, "ArrayRef<AffineMap>":$ubMaps, "ValueRange":$ubArgs,
"ArrayRef<int64_t>":$steps)>
];
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
index 4fa592d003a0f..31d6239388454 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -109,6 +109,18 @@ bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
const APFloat &rhs);
+/// Returns the identity value attribute associated with an AtomicRMWKind op.
+Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+ OpBuilder &builder, Location loc);
+
+/// Returns the identity value associated with an AtomicRMWKind op.
+Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
+ Location loc);
+
+/// Returns the value obtained by applying the reduction operation kind
+/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
+Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
+ Value lhs, Value rhs);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
index 87439da956407..704edbb587bd3 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
@@ -68,4 +68,28 @@ def Arith_CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir::arith";
}
+def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
+def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
+def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
+def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
+def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
+def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
+def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
+def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
+def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
+def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
+
+def AtomicRMWKindAttr : I64EnumAttr<
+ "AtomicRMWKind", "",
+ [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
+ ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+ ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+ ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
+ ATOMIC_RMW_KIND_ANDI]> {
+ let cppNamespace = "::mlir::arith";
+}
+
#endif // ARITHMETIC_BASE
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index e529a50dae935..884dc0f5b0510 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -11,6 +11,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
+include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
@@ -1673,4 +1674,51 @@ def MemRef_ViewOp : MemRef_Op<"view", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
+ AllTypesMatch<["value", "result"]>,
+ TypesMatchWith<"value type matches element type of memref",
+ "memref", "value",
+ "$_self.cast<MemRefType>().getElementType()">
+ ]> {
+ let summary = "atomic read-modify-write operation";
+ let description = [{
+ The `atomic_rmw` operation provides a way to perform a read-modify-write
+ sequence that is free from data races. The kind enumeration specifies the
+ modification to perform. The value operand represents the new value to be
+ applied during the modification. The memref operand represents the buffer
+ that the read and write will be performed against, as accessed by the
+ specified indices. The arity of the indices is the rank of the memref. The
+ result represents the latest value that was stored.
+
+ Example:
+
+ ```mlir
+ %x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
+ ```
+ }];
+
+ let arguments = (ins
+ AtomicRMWKindAttr:$kind,
+ AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
+ MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+ Variadic<Index>:$indices);
+ let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+ let assemblyFormat = [{
+ $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
+ type($memref) `)` `->` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return memref().getType().cast<MemRefType>();
+ }
+ }];
+ let hasFolder = 1;
+}
+
#endif // MEMREF_OPS
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index fe9c3c6e26f1b..b309488aa4f53 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -42,31 +42,4 @@ class PatternRewriter;
#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
-namespace mlir {
-
-/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
-/// comparison predicates.
-bool applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs,
- const APInt &rhs);
-
-/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
-/// comparison predicates.
-bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
- const APFloat &rhs);
-
-/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
- OpBuilder &builder, Location loc);
-
-/// Returns the identity value associated with an AtomicRMWKind op.
-Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
- Location loc);
-
-/// Returns the value obtained by applying the reduction operation kind
-/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
-Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
- Value lhs, Value rhs);
-
-} // namespace mlir
-
#endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2e50971db9e7a..794f0157ef144 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -13,7 +13,6 @@
#ifndef STANDARD_OPS
#define STANDARD_OPS
-include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
@@ -179,52 +178,6 @@ def AssertOp : Std_Op<"assert"> {
let hasCanonicalizeMethod = 1;
}
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-def AtomicRMWOp : Std_Op<"atomic_rmw", [
- AllTypesMatch<["value", "result"]>,
- TypesMatchWith<"value type matches element type of memref",
- "memref", "value",
- "$_self.cast<MemRefType>().getElementType()">
- ]> {
- let summary = "atomic read-modify-write operation";
- let description = [{
- The `atomic_rmw` operation provides a way to perform a read-modify-write
- sequence that is free from data races. The kind enumeration specifies the
- modification to perform. The value operand represents the new value to be
- applied during the modification. The memref operand represents the buffer
- that the read and write will be performed against, as accessed by the
- specified indices. The arity of the indices is the rank of the memref. The
- result represents the latest value that was stored.
-
- Example:
-
- ```mlir
- %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
- ```
- }];
-
- let arguments = (ins
- AtomicRMWKindAttr:$kind,
- AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value,
- MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
- Variadic<Index>:$indices);
- let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
-
- let assemblyFormat = [{
- $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,`
- type($memref) `)` `->` type($result)
- }];
-
- let extraClassDeclaration = [{
- MemRefType getMemRefType() {
- return getMemref().getType().cast<MemRefType>();
- }
- }];
-}
-
def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td b/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
deleted file mode 100644
index 3016a197df0d0..0000000000000
--- a/mlir/include/mlir/Dialect/StandardOps/IR/StandardOpsBase.td
+++ /dev/null
@@ -1,42 +0,0 @@
-//===- StandardOpsBase.td - Standard ops definitions -------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Defines base support for standard operations.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef STANDARD_OPS_BASE
-#define STANDARD_OPS_BASE
-
-include "mlir/IR/OpBase.td"
-
-def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>;
-def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>;
-def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>;
-def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>;
-def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>;
-def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>;
-def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>;
-def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>;
-def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>;
-def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
-def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
-def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
-def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
-
-def AtomicRMWKindAttr : I64EnumAttr<
- "AtomicRMWKind", "",
- [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
- ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
- ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
- ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
- ATOMIC_RMW_KIND_ANDI]> {
- let cppNamespace = "::mlir";
-}
-
-#endif // STANDARD_OPS_BASE
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 14bd03968fcf6..816ec204acfe7 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_VECTOR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_VECTOROPS_H
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@@ -145,8 +146,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
/// Returns the value obtained by reducing the vector into a scalar using the
/// operation kind associated with a binary AtomicRMWKind op.
-Value getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
- Value vector);
+Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
+ Location loc, Value vector);
/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 79a367e337133..c8022e0465483 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -40,7 +40,7 @@ using llvm::dbgs;
/// reduction kind suitable for use in affine parallel loop builder. If the
/// reduction is not supported, returns null.
static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
- AtomicRMWKind &kind) {
+ arith::AtomicRMWKind &kind) {
SmallVector<Operation *> combinerOps;
Value reducedVal =
matchReduction(forOp.getRegionIterArgs(), pos, combinerOps);
@@ -52,21 +52,21 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
return nullptr;
Operation *combinerOp = combinerOps.back();
- Optional<AtomicRMWKind> maybeKind =
- TypeSwitch<Operation *, Optional<AtomicRMWKind>>(combinerOp)
- .Case([](arith::AddFOp) { return AtomicRMWKind::addf; })
- .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
- .Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
- .Case([](arith::AndIOp) { return AtomicRMWKind::andi; })
- .Case([](arith::OrIOp) { return AtomicRMWKind::ori; })
- .Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
- .Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
- .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
- .Case([](arith::MinSIOp) { return AtomicRMWKind::mins; })
- .Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; })
- .Case([](arith::MinUIOp) { return AtomicRMWKind::minu; })
- .Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; })
- .Default([](Operation *) -> Optional<AtomicRMWKind> {
+ Optional<arith::AtomicRMWKind> maybeKind =
+ TypeSwitch<Operation *, Optional<arith::AtomicRMWKind>>(combinerOp)
+ .Case([](arith::AddFOp) { return arith::AtomicRMWKind::addf; })
+ .Case([](arith::MulFOp) { return arith::AtomicRMWKind::mulf; })
+ .Case([](arith::AddIOp) { return arith::AtomicRMWKind::addi; })
+ .Case([](arith::AndIOp) { return arith::AtomicRMWKind::andi; })
+ .Case([](arith::OrIOp) { return arith::AtomicRMWKind::ori; })
+ .Case([](arith::MulIOp) { return arith::AtomicRMWKind::muli; })
+ .Case([](arith::MinFOp) { return arith::AtomicRMWKind::minf; })
+ .Case([](arith::MaxFOp) { return arith::AtomicRMWKind::maxf; })
+ .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
+ .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
+ .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
+ .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
+ .Default([](Operation *) -> Optional<arith::AtomicRMWKind> {
// TODO: AtomicRMW supports other kinds of reductions this is
// currently not detecting, add those when the need arises.
return llvm::None;
@@ -86,7 +86,7 @@ void mlir::getSupportedReductions(
return;
supportedReductions.reserve(numIterArgs);
for (unsigned i = 0; i < numIterArgs; ++i) {
- AtomicRMWKind kind;
+ arith::AtomicRMWKind kind;
if (Value value = getSupportedReduction(forOp, i, kind))
supportedReductions.emplace_back(LoopReduction{kind, i, value});
}
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 8b48549b38052..bc2f5917160e7 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -430,13 +430,14 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
// initialization of the result values.
Attribute reduction = std::get<0>(pair);
Type resultType = std::get<1>(pair);
- Optional<AtomicRMWKind> reductionOp = symbolizeAtomicRMWKind(
- static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
+ Optional<arith::AtomicRMWKind> reductionOp =
+ arith::symbolizeAtomicRMWKind(
+ static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
assert(reductionOp.hasValue() &&
"Reduction operation cannot be of None Type");
- AtomicRMWKind reductionOpValue = reductionOp.getValue();
+ arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
identityVals.push_back(
- getIdentityValue(reductionOpValue, resultType, rewriter, loc));
+ arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
parOp = rewriter.create<scf::ParallelOp>(
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
@@ -450,16 +451,17 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
"Unequal number of reductions and operands.");
for (unsigned i = 0, end = reductions.size(); i < end; i++) {
// For each of the reduction operations get the respective mlir::Value.
- Optional<AtomicRMWKind> reductionOp =
- symbolizeAtomicRMWKind(reductions[i].cast<IntegerAttr>().getInt());
+ Optional<arith::AtomicRMWKind> reductionOp =
+ arith::symbolizeAtomicRMWKind(
+ reductions[i].cast<IntegerAttr>().getInt());
assert(reductionOp.hasValue() &&
"Reduction Operation cannot be of None Type");
- AtomicRMWKind reductionOpValue = reductionOp.getValue();
+ arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
rewriter.setInsertionPoint(&parOp.getBody()->back());
auto reduceOp = rewriter.create<scf::ReduceOp>(
loc, affineParOpTerminator->getOperand(i));
rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
- Value reductionResult = getReductionOp(
+ Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc,
reduceOp.getReductionOperator().front().getArgument(0),
reduceOp.getReductionOperator().front().getArgument(1));
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 28981dd87ecc9..b1f7d0452ee13 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1553,6 +1553,62 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
}
};
+//===----------------------------------------------------------------------===//
+// AtomicRMWOpLowering
+//===----------------------------------------------------------------------===//
+
+/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
+/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
+static Optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
+ switch (atomicOp.kind()) {
+ case arith::AtomicRMWKind::addf:
+ return LLVM::AtomicBinOp::fadd;
+ case arith::AtomicRMWKind::addi:
+ return LLVM::AtomicBinOp::add;
+ case arith::AtomicRMWKind::assign:
+ return LLVM::AtomicBinOp::xchg;
+ case arith::AtomicRMWKind::maxs:
+ return LLVM::AtomicBinOp::max;
+ case arith::AtomicRMWKind::maxu:
+ return LLVM::AtomicBinOp::umax;
+ case arith::AtomicRMWKind::mins:
+ return LLVM::AtomicBinOp::min;
+ case arith::AtomicRMWKind::minu:
+ return LLVM::AtomicBinOp::umin;
+ case arith::AtomicRMWKind::ori:
+ return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::andi:
+ return LLVM::AtomicBinOp::_and;
+ default:
+ return llvm::None;
+ }
+ llvm_unreachable("Invalid AtomicRMWKind");
+}
+
+struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
+ using Base::Base;
+
+ LogicalResult
+ matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (failed(match(atomicOp)))
+ return failure();
+ auto maybeKind = matchSimpleAtomicOp(atomicOp);
+ if (!maybeKind)
+ return failure();
+ auto resultType = adaptor.value().getType();
+ auto memRefType = atomicOp.getMemRefType();
+ auto dataPtr =
+ getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(),
+ adaptor.indices(), rewriter);
+ rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
+ atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(),
+ LLVM::AtomicOrdering::acq_rel);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1561,6 +1617,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<
AllocaOpLowering,
AllocaScopeOpLowering,
+ AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
DimOpLowering,
GlobalMemrefOpLowering,
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index da429dd8af119..feaa140cc710e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -772,61 +772,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
-} // namespace
-
-/// Try to match the kind of a std.atomic_rmw to determine whether to use a
-/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
-static Optional<LLVM::AtomicBinOp> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
- switch (atomicOp.getKind()) {
- case AtomicRMWKind::addf:
- return LLVM::AtomicBinOp::fadd;
- case AtomicRMWKind::addi:
- return LLVM::AtomicBinOp::add;
- case AtomicRMWKind::assign:
- return LLVM::AtomicBinOp::xchg;
- case AtomicRMWKind::maxs:
- return LLVM::AtomicBinOp::max;
- case AtomicRMWKind::maxu:
- return LLVM::AtomicBinOp::umax;
- case AtomicRMWKind::mins:
- return LLVM::AtomicBinOp::min;
- case AtomicRMWKind::minu:
- return LLVM::AtomicBinOp::umin;
- case AtomicRMWKind::ori:
- return LLVM::AtomicBinOp::_or;
- case AtomicRMWKind::andi:
- return LLVM::AtomicBinOp::_and;
- default:
- return llvm::None;
- }
- llvm_unreachable("Invalid AtomicRMWKind");
-}
-
-namespace {
-
-struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
- using Base::Base;
-
- LogicalResult
- matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- if (failed(match(atomicOp)))
- return failure();
- auto maybeKind = matchSimpleAtomicOp(atomicOp);
- if (!maybeKind)
- return failure();
- auto resultType = adaptor.getValue().getType();
- auto memRefType = atomicOp.getMemRefType();
- auto dataPtr =
- getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
- adaptor.getIndices(), rewriter);
- rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
- atomicOp, resultType, *maybeKind, dataPtr, adaptor.getValue(),
- LLVM::AtomicOrdering::acq_rel);
- return success();
- }
-};
-
/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
@@ -962,7 +907,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
// clang-format off
patterns.add<
AssertOpLowering,
- AtomicRMWOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
CallOpLowering,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 071838dcf2be4..c3c1b51294801 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2801,7 +2801,7 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
- ArrayRef<AtomicRMWKind> reductions,
+ ArrayRef<arith::AtomicRMWKind> reductions,
ArrayRef<int64_t> ranges) {
SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
@@ -2814,7 +2814,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
TypeRange resultTypes,
- ArrayRef<AtomicRMWKind> reductions,
+ ArrayRef<arith::AtomicRMWKind> reductions,
ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
ArrayRef<int64_t> steps) {
@@ -2843,7 +2843,7 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
// Convert the reductions to integer attributes.
SmallVector<Attribute, 4> reductionAttrs;
- for (AtomicRMWKind reduction : reductions)
+ for (arith::AtomicRMWKind reduction : reductions)
reductionAttrs.push_back(
builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
result.addAttribute(getReductionsAttrName(),
@@ -3050,7 +3050,7 @@ static LogicalResult verify(AffineParallelOp op) {
// Verify reduction ops are all valid
for (Attribute attr : op.reductions()) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
- if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
+ if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
return op.emitOpError("invalid reduction attribute");
}
@@ -3150,9 +3150,9 @@ static void print(OpAsmPrinter &p, AffineParallelOp op) {
if (op.getNumResults()) {
p << " reduce (";
llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
- AtomicRMWKind sym =
- *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
- p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
+ arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
+ attr.template cast<IntegerAttr>().getInt());
+ p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
});
p << ") -> (" << op.getResultTypes() << ")";
}
@@ -3374,8 +3374,8 @@ static ParseResult parseAffineParallelOp(OpAsmParser &parser,
if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
attrStorage))
return failure();
- llvm::Optional<AtomicRMWKind> reduction =
- symbolizeAtomicRMWKind(attrVal.getValue());
+ llvm::Optional<arith::AtomicRMWKind> reduction =
+ arith::symbolizeAtomicRMWKind(attrVal.getValue());
if (!reduction)
return parser.emitError(loc, "invalid reduction value: ") << attrVal;
reductions.push_back(builder.getI64IntegerAttr(
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 9d59b89ea1a2c..7ecc6750bcca4 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -971,7 +971,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
/// Creates a constant vector filled with the neutral elements of the given
/// reduction. The scalar type of vector elements will be taken from
/// `oldOperand`.
-static arith::ConstantOp createInitialVector(AtomicRMWKind reductionKind,
+static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
Value oldOperand,
VectorizationState &state) {
Type scalarTy = oldOperand.getType();
@@ -1245,8 +1245,8 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
/// Returns true if `value` is a constant equal to the neutral element of the
/// given vectorizable reduction.
-static bool isNeutralElementConst(AtomicRMWKind reductionKind, Value value,
- VectorizationState &state) {
+static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
+ Value value, VectorizationState &state) {
Type scalarTy = value.getType();
if (!VectorType::isValidElementType(scalarTy))
return false;
@@ -1361,7 +1361,8 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp,
Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
Value finalRes = reducedRes;
if (!isNeutralElementConst(reductions[i].kind, origInit, state))
- finalRes = getReductionOp(reductions[i].kind, state.builder,
+ finalRes =
+ arith::getReductionOp(reductions[i].kind, state.builder,
reducedRes.getLoc(), reducedRes, origInit);
state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
}
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index f0ce1b7a4d704..048e4d89186cd 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
@@ -1208,6 +1209,101 @@ OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
return BoolAttr::get(getContext(), val);
}
+//===----------------------------------------------------------------------===//
+// Atomic Enum
+//===----------------------------------------------------------------------===//
+
+/// Returns the identity value attribute associated with an AtomicRMWKind op.
+Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
+ OpBuilder &builder, Location loc) {
+ switch (kind) {
+ case AtomicRMWKind::maxf:
+ return builder.getFloatAttr(
+ resultType,
+ APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+ /*Negative=*/true));
+ case AtomicRMWKind::addf:
+ case AtomicRMWKind::addi:
+ case AtomicRMWKind::maxu:
+ case AtomicRMWKind::ori:
+ return builder.getZeroAttr(resultType);
+ case AtomicRMWKind::andi:
+ return builder.getIntegerAttr(
+ resultType,
+ APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
+ case AtomicRMWKind::maxs:
+ return builder.getIntegerAttr(
+ resultType,
+ APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
+ case AtomicRMWKind::minf:
+ return builder.getFloatAttr(
+ resultType,
+ APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+ /*Negative=*/false));
+ case AtomicRMWKind::mins:
+ return builder.getIntegerAttr(
+ resultType,
+ APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
+ case AtomicRMWKind::minu:
+ return builder.getIntegerAttr(
+ resultType,
+ APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
+ case AtomicRMWKind::muli:
+ return builder.getIntegerAttr(resultType, 1);
+ case AtomicRMWKind::mulf:
+ return builder.getFloatAttr(resultType, 1);
+ // TODO: Add remaining reduction operations.
+ default:
+ (void)emitOptionalError(loc, "Reduction operation type not supported");
+ break;
+ }
+ return nullptr;
+}
+
+/// Returns the identity value associated with an AtomicRMWKind op.
+Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
+ OpBuilder &builder, Location loc) {
+ Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
+ return builder.create<arith::ConstantOp>(loc, attr);
+}
+
+/// Return the value obtained by applying the reduction operation kind
+/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
+Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
+ Location loc, Value lhs, Value rhs) {
+ switch (op) {
+ case AtomicRMWKind::addf:
+ return builder.create<arith::AddFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::addi:
+ return builder.create<arith::AddIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::mulf:
+ return builder.create<arith::MulFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::muli:
+ return builder.create<arith::MulIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::maxf:
+ return builder.create<arith::MaxFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::minf:
+ return builder.create<arith::MinFOp>(loc, lhs, rhs);
+ case AtomicRMWKind::maxs:
+ return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::mins:
+ return builder.create<arith::MinSIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::maxu:
+ return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::minu:
+ return builder.create<arith::MinUIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::ori:
+ return builder.create<arith::OrIOp>(loc, lhs, rhs);
+ case AtomicRMWKind::andi:
+ return builder.create<arith::AndIOp>(loc, lhs, rhs);
+ // TODO: Add remaining reduction operations.
+ default:
+ (void)emitOptionalError(loc, "Reduction operation type not supported");
+ break;
+ }
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ab7e8305ab5b8..45ba726d5bf99 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2286,6 +2286,50 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicRMWOp op) {
+ if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
+ return op.emitOpError(
+ "expects the number of subscripts to be equal to memref rank");
+ switch (op.kind()) {
+ case arith::AtomicRMWKind::addf:
+ case arith::AtomicRMWKind::maxf:
+ case arith::AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::mulf:
+ if (!op.value().getType().isa<FloatType>())
+ return op.emitOpError()
+ << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
+ << "' expects a floating-point type";
+ break;
+ case arith::AtomicRMWKind::addi:
+ case arith::AtomicRMWKind::maxs:
+ case arith::AtomicRMWKind::maxu:
+ case arith::AtomicRMWKind::mins:
+ case arith::AtomicRMWKind::minu:
+ case arith::AtomicRMWKind::muli:
+ case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::andi:
+ if (!op.value().getType().isa<IntegerType>())
+ return op.emitOpError()
+ << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
+ << "' expects an integer type";
+ break;
+ default:
+ break;
+ }
+ return success();
+}
+
+OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
+ /// atomicrmw(memrefcast) -> atomicrmw
+ if (succeeded(foldMemRefCast(*this, value())))
+ return getResult();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 02d54472baf50..a74b46c034c42 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -131,134 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
-//===----------------------------------------------------------------------===//
-// AtomicRMWOp
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verify(AtomicRMWOp op) {
- if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
- return op.emitOpError(
- "expects the number of subscripts to be equal to memref rank");
- switch (op.getKind()) {
- case AtomicRMWKind::addf:
- case AtomicRMWKind::maxf:
- case AtomicRMWKind::minf:
- case AtomicRMWKind::mulf:
- if (!op.getValue().getType().isa<FloatType>())
- return op.emitOpError()
- << "with kind '" << stringifyAtomicRMWKind(op.getKind())
- << "' expects a floating-point type";
- break;
- case AtomicRMWKind::addi:
- case AtomicRMWKind::maxs:
- case AtomicRMWKind::maxu:
- case AtomicRMWKind::mins:
- case AtomicRMWKind::minu:
- case AtomicRMWKind::muli:
- case AtomicRMWKind::ori:
- case AtomicRMWKind::andi:
- if (!op.getValue().getType().isa<IntegerType>())
- return op.emitOpError()
- << "with kind '" << stringifyAtomicRMWKind(op.getKind())
- << "' expects an integer type";
- break;
- default:
- break;
- }
- return success();
-}
-
-/// Returns the identity value attribute associated with an AtomicRMWKind op.
-Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
- OpBuilder &builder, Location loc) {
- switch (kind) {
- case AtomicRMWKind::maxf:
- return builder.getFloatAttr(
- resultType,
- APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
- /*Negative=*/true));
- case AtomicRMWKind::addf:
- case AtomicRMWKind::addi:
- case AtomicRMWKind::maxu:
- case AtomicRMWKind::ori:
- return builder.getZeroAttr(resultType);
- case AtomicRMWKind::andi:
- return builder.getIntegerAttr(
- resultType,
- APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
- case AtomicRMWKind::maxs:
- return builder.getIntegerAttr(
- resultType,
- APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
- case AtomicRMWKind::minf:
- return builder.getFloatAttr(
- resultType,
- APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
- /*Negative=*/false));
- case AtomicRMWKind::mins:
- return builder.getIntegerAttr(
- resultType,
- APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
- case AtomicRMWKind::minu:
- return builder.getIntegerAttr(
- resultType,
- APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
- case AtomicRMWKind::muli:
- return builder.getIntegerAttr(resultType, 1);
- case AtomicRMWKind::mulf:
- return builder.getFloatAttr(resultType, 1);
- // TODO: Add remaining reduction operations.
- default:
- (void)emitOptionalError(loc, "Reduction operation type not supported");
- break;
- }
- return nullptr;
-}
-
-/// Returns the identity value associated with an AtomicRMWKind op.
-Value mlir::getIdentityValue(AtomicRMWKind op, Type resultType,
- OpBuilder &builder, Location loc) {
- Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
- return builder.create<arith::ConstantOp>(loc, attr);
-}
-
-/// Return the value obtained by applying the reduction operation kind
-/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
-Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
- Value lhs, Value rhs) {
- switch (op) {
- case AtomicRMWKind::addf:
- return builder.create<arith::AddFOp>(loc, lhs, rhs);
- case AtomicRMWKind::addi:
- return builder.create<arith::AddIOp>(loc, lhs, rhs);
- case AtomicRMWKind::mulf:
- return builder.create<arith::MulFOp>(loc, lhs, rhs);
- case AtomicRMWKind::muli:
- return builder.create<arith::MulIOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxf:
- return builder.create<arith::MaxFOp>(loc, lhs, rhs);
- case AtomicRMWKind::minf:
- return builder.create<arith::MinFOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxs:
- return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
- case AtomicRMWKind::mins:
- return builder.create<arith::MinSIOp>(loc, lhs, rhs);
- case AtomicRMWKind::maxu:
- return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
- case AtomicRMWKind::minu:
- return builder.create<arith::MinUIOp>(loc, lhs, rhs);
- case AtomicRMWKind::ori:
- return builder.create<arith::OrIOp>(loc, lhs, rhs);
- case AtomicRMWKind::andi:
- return builder.create<arith::AndIOp>(loc, lhs, rhs);
- // TODO: Add remaining reduction operations.
- default:
- (void)emitOptionalError(loc, "Reduction operation type not supported");
- break;
- }
- return nullptr;
-}
-
//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 71a1a55903c65..a62f6a076f936 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -40,18 +40,18 @@ namespace {
/// %new_value = select %cmp, %current, %fval : f32
/// atomic_yield %new_value : f32
/// }
-struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
+struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(AtomicRMWOp op,
+ LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
PatternRewriter &rewriter) const final {
arith::CmpFPredicate predicate;
- switch (op.getKind()) {
- case AtomicRMWKind::maxf:
+ switch (op.kind()) {
+ case arith::AtomicRMWKind::maxf:
predicate = arith::CmpFPredicate::OGT;
break;
- case AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::minf:
predicate = arith::CmpFPredicate::OLT;
break;
default:
@@ -59,13 +59,13 @@ struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> {
}
auto loc = op.getLoc();
- auto genericOp = rewriter.create<GenericAtomicRMWOp>(loc, op.getMemref(),
- op.getIndices());
+ auto genericOp =
+ rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
Value lhs = genericOp.getCurrentValue();
- Value rhs = op.getValue();
+ Value rhs = op.value();
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
bodyBuilder.create<AtomicYieldOp>(loc, select);
@@ -130,10 +130,11 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect>();
- target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
- return op.getKind() != AtomicRMWKind::maxf &&
- op.getKind() != AtomicRMWKind::minf;
- });
+ target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
+ [](memref::AtomicRMWOp op) {
+ return op.kind() != arith::AtomicRMWKind::maxf &&
+ op.kind() != arith::AtomicRMWKind::minf;
+ });
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
});
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 3a65d0e93dfdb..60c0aac1a4bea 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -359,41 +359,42 @@ static void print(OpAsmPrinter &p, ReductionOp op) {
p << " : " << op.vector().getType() << " into " << op.dest().getType();
}
-Value mlir::vector::getVectorReductionOp(AtomicRMWKind op, OpBuilder &builder,
- Location loc, Value vector) {
+Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
+ OpBuilder &builder, Location loc,
+ Value vector) {
Type scalarType = vector.getType().cast<ShapedType>().getElementType();
switch (op) {
- case AtomicRMWKind::addf:
- case AtomicRMWKind::addi:
+ case arith::AtomicRMWKind::addf:
+ case arith::AtomicRMWKind::addi:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("add"),
vector, ValueRange{});
- case AtomicRMWKind::mulf:
- case AtomicRMWKind::muli:
+ case arith::AtomicRMWKind::mulf:
+ case arith::AtomicRMWKind::muli:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("mul"),
vector, ValueRange{});
- case AtomicRMWKind::minf:
+ case arith::AtomicRMWKind::minf:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minf"),
vector, ValueRange{});
- case AtomicRMWKind::mins:
+ case arith::AtomicRMWKind::mins:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minsi"),
vector, ValueRange{});
- case AtomicRMWKind::minu:
+ case arith::AtomicRMWKind::minu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("minui"),
vector, ValueRange{});
- case AtomicRMWKind::maxf:
+ case arith::AtomicRMWKind::maxf:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxf"),
vector, ValueRange{});
- case AtomicRMWKind::maxs:
+ case arith::AtomicRMWKind::maxs:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxsi"),
vector, ValueRange{});
- case AtomicRMWKind::maxu:
+ case arith::AtomicRMWKind::maxu:
return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
builder.getStringAttr("maxui"),
vector, ValueRange{});
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index fbb79b5af3f86..91d4a7cd5d195 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1551,7 +1551,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp,
for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
rhs = forOp.getResult(i * oldNumResults + pos);
// Create ops based on reduction type.
- lhs = getReductionOp(reduction.kind, builder, loc, lhs, rhs);
+ lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
if (!lhs)
return failure();
Operation *op = lhs.getDefiningOp();
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5682c853964c8..70ba47d2d176b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -859,3 +859,28 @@ func @rank_of_ranked(%ranked: memref<?xi32>) {
}
// CHECK: llvm.mlir.constant(1 : index) : i64
// CHECK32: llvm.mlir.constant(1 : index) : i32
+
+// -----
+
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
+ memref.atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
+ memref.atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
+ // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
+ return
+}
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index c3282e1903d6f..0dc6bf10dc5ea 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -486,31 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// -----
-// CHECK-LABEL: func @atomic_rmw
-func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
- atomic_rmw assign %fval, %F[%i] : (f32, memref<10xf32>) -> f32
- // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
- atomic_rmw addi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
- atomic_rmw maxs %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
- atomic_rmw mins %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
- atomic_rmw maxu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
- atomic_rmw minu %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
- atomic_rmw addf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
- // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
- atomic_rmw ori %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
- atomic_rmw andi %ival, %I[%i] : (i32, memref<10xi32>) -> i32
- // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
- return
-}
-
-// -----
-
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
%x = generic_atomic_rmw %I[%i] : memref<10xi32> {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 39f9847f4c9e2..2e81705049f54 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -499,3 +499,14 @@ func @no_fold_dynamic_no_op_subview(%arg0 : memref<?x?xf32>) -> memref<?x?xf32,
// CHECK-LABEL: func @no_fold_dynamic_no_op_subview(
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
+
+// -----
+
+func @atomicrmw_cast_fold(%arg0 : f32, %arg1 : memref<4xf32>, %c : index) {
+ %v = memref.cast %arg1 : memref<4xf32> to memref<?xf32>
+ %a = memref.atomic_rmw addf %arg0, %v[%c] : (f32, memref<?xf32>) -> f32
+ return
+}
+
+// CHECK-LABEL: func @atomicrmw_cast_fold
+// CHECK-NEXT: memref.atomic_rmw addf %arg0, %arg1[%arg2] : (f32, memref<4xf32>) -> f32
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 5cf32703c9ebb..90f851959748c 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -848,3 +848,27 @@ func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : inde
// expected-error@+1 {{expected 3 offset values}}
%0 = memref.subview %arg0[0, 0] [%arg1, %arg2] [1, 1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map>
}
+
+// -----
+
+func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
+ // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
+ %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
+ return
+}
+
+// -----
+
+func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
+ // expected-error@+1 {{expects a floating-point type}}
+ %x = memref.atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
+ return
+}
+
+// -----
+
+func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
+ // expected-error@+1 {{expects an integer type}}
+ %x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 963c817af3981..71b6038a2f9d2 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -227,3 +227,13 @@ func @rank(%t : memref<4x4x?xf32>) {
%1 = memref.rank %t : memref<4x4x?xf32>
return
}
+
+// ------
+
+// CHECK-LABEL: func @atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
+func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
+ %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
+ // CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
+ return
+}
diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index 45659aee0763d..cb650ffd11bd5 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -3,7 +3,7 @@
// CHECK-LABEL: func @atomic_rmw_to_generic
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
- %x = atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
@@ -18,7 +18,7 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
// CHECK-LABEL: func @atomic_rmw_no_conversion
func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
- %x = atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
+ %x = memref.atomic_rmw addf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK-NOT: generic_atomic_rmw
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index b83f530eeacc6..351e8a6b39c1b 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -325,14 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
return
}
-// CHECK-LABEL: func @atomic_rmw
-// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
-func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
- %x = atomic_rmw addf %val, %I[%i] : (f32, memref<10xf32>) -> f32
- // CHECK: atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
- return
-}
-
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index 49f29f09bf492..2aae390af4d02 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -130,30 +130,6 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
// -----
-func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
- // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
- %x = atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
- return
-}
-
-// -----
-
-func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) {
- // expected-error@+1 {{expects a floating-point type}}
- %x = atomic_rmw addf %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32
- return
-}
-
-// -----
-
-func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
- // expected-error@+1 {{expects an integer type}}
- %x = atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
- return
-}
-
-// -----
-
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
More information about the Mlir-commits
mailing list