[Mlir-commits] [mlir] 04b449e - The fillOp's value needs to be cast
Prashant Kumar
llvmlistbot at llvm.org
Wed Nov 9 19:43:30 PST 2022
Author: Prashant Kumar
Date: 2022-11-10T03:43:22Z
New Revision: 04b449e147f6be1b466455639055019e508f4137
URL: https://github.com/llvm/llvm-project/commit/04b449e147f6be1b466455639055019e508f4137
DIFF: https://github.com/llvm/llvm-project/commit/04b449e147f6be1b466455639055019e508f4137.diff
LOG: The fillOp's value needs to be cast
During elementwise fusion, the fillOp's value was referenced directly
without a cast, which could produce mismatched dtypes.
Reviewed By: mravishankar, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D137447
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Utils/Utils.h
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d3073309963ec..d7aa7dbb67bef 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -92,6 +92,12 @@ SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec);
+/// Converts a scalar value `operand` to type `toType` by creating arith cast
+/// ops with builder `b` at location `loc`. Returns `operand` itself when it
+/// already has type `toType`. When `isUnsignedCast` is true, the unsigned
+/// variants of the casts are used (zero extension, unsigned int<->float
+/// conversion); otherwise the signed variants are used. If the value doesn't
+/// convert, a warning will be issued and the operand is returned as is (which
+/// will presumably yield a verification issue downstream).
+Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+ Type toType, bool isUnsignedCast);
+
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 210c580b708b7..cf9fdc232e1a8 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -80,6 +80,50 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
}
+/// Converts the scalar `operand` to `toType` by emitting arith cast ops with
+/// `b` at `loc`. Supported conversions:
+///   - float/index/integer -> integer (fptosi|fptoui, index_cast,
+///     extsi|extui, trunci);
+///   - integer/float -> index (index_cast; floats go through i64 first);
+///   - integer/index/float -> float (sitofp|uitofp, extf, truncf; index goes
+///     through i64, and distinct equal-width formats such as bf16<->f16 go
+///     through f32).
+/// `isUnsignedCast` selects the unsigned variant wherever arith has one; the
+/// index<->integer step always uses arith.index_cast. Any other combination
+/// emits a warning and returns `operand` unchanged.
+Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
+                                 Type toType, bool isUnsignedCast) {
+  Type fromType = operand.getType();
+  if (fromType == toType)
+    return operand;
+
+  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+    // Float -> integer.
+    if (fromType.isa<FloatType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::FPToUIOp>(loc, toType, operand);
+      return b.create<arith::FPToSIOp>(loc, toType, operand);
+    }
+    // Index -> integer.
+    if (fromType.isIndex())
+      return b.create<arith::IndexCastOp>(loc, toType, operand);
+    // Integer -> integer: either extend or truncate.
+    if (auto fromIntType = fromType.dyn_cast<IntegerType>()) {
+      if (toIntType.getWidth() > fromIntType.getWidth()) {
+        if (isUnsignedCast)
+          return b.create<arith::ExtUIOp>(loc, toType, operand);
+        return b.create<arith::ExtSIOp>(loc, toType, operand);
+      }
+      if (toIntType.getWidth() < fromIntType.getWidth())
+        return b.create<arith::TruncIOp>(loc, toType, operand);
+    }
+  } else if (toType.isIndex()) {
+    // Integer -> index.
+    if (fromType.isa<IntegerType>())
+      return b.create<arith::IndexCastOp>(loc, toType, operand);
+    // Float -> index: arith has no direct float<->index cast, so convert to
+    // i64 first and then index_cast.
+    if (fromType.isa<FloatType>()) {
+      Type i64Type = b.getI64Type();
+      Value asInt;
+      if (isUnsignedCast)
+        asInt = b.create<arith::FPToUIOp>(loc, i64Type, operand);
+      else
+        asInt = b.create<arith::FPToSIOp>(loc, i64Type, operand);
+      return b.create<arith::IndexCastOp>(loc, toType, asInt);
+    }
+  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+    // Integer -> float.
+    if (fromType.isa<IntegerType>()) {
+      if (isUnsignedCast)
+        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
+      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
+    }
+    // Index -> float: sitofp/uitofp do not accept index, so go through i64.
+    if (fromType.isIndex()) {
+      Value asInt =
+          b.create<arith::IndexCastOp>(loc, b.getI64Type(), operand);
+      if (isUnsignedCast)
+        return b.create<arith::UIToFPOp>(loc, toFloatType, asInt);
+      return b.create<arith::SIToFPOp>(loc, toFloatType, asInt);
+    }
+    // Float -> float.
+    if (auto fromFloatType = fromType.dyn_cast<FloatType>()) {
+      unsigned toWidth = toFloatType.getWidth();
+      unsigned fromWidth = fromFloatType.getWidth();
+      if (toWidth > fromWidth)
+        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
+      if (toWidth < fromWidth)
+        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
+      // Distinct formats of equal width (e.g. bf16 <-> f16) cannot be cast
+      // with a single extf/truncf. Every float format narrower than 32 bits is
+      // exactly representable in f32, so widening first and truncating once
+      // rounds only a single time.
+      if (toWidth < 32) {
+        Value wide = b.create<arith::ExtFOp>(loc, b.getF32Type(), operand);
+        return b.create<arith::TruncFOp>(loc, toFloatType, wide);
+      }
+    }
+  }
+
+  emitWarning(loc) << "could not cast operand of type " << fromType << " to "
+                   << toType;
+  return operand;
+}
+
SmallVector<Value>
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
ArrayRef<OpFoldResult> valueOrAttrVec) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8ce1ad070f46a..fc62407e5d375 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -423,48 +423,7 @@ class RegionBuilderHelper {
Value cast(Type toType, Value operand, bool isUnsignedCast) {
OpBuilder builder = getBuilder();
auto loc = operand.getLoc();
-
- if (operand.getType() == toType)
- return operand;
- if (auto toIntType = toType.dyn_cast<IntegerType>()) {
- // If operand is floating point, cast directly to the int type.
- if (operand.getType().isa<FloatType>()) {
- if (isUnsignedCast)
- return builder.create<arith::FPToUIOp>(loc, toType, operand);
- return builder.create<arith::FPToSIOp>(loc, toType, operand);
- }
- // Cast index operands directly to the int type.
- if (operand.getType().isIndex())
- return builder.create<arith::IndexCastOp>(loc, toType, operand);
- if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
- // Either extend or truncate.
- if (toIntType.getWidth() > fromIntType.getWidth()) {
- if (isUnsignedCast)
- return builder.create<arith::ExtUIOp>(loc, toType, operand);
- return builder.create<arith::ExtSIOp>(loc, toType, operand);
- }
- if (toIntType.getWidth() < fromIntType.getWidth())
- return builder.create<arith::TruncIOp>(loc, toType, operand);
- }
- } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
- // If operand is integer, cast directly to the float type.
- // Note that it is unclear how to cast from BF16<->FP16.
- if (operand.getType().isa<IntegerType>()) {
- if (isUnsignedCast)
- return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
- return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
- }
- if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
- if (toFloatType.getWidth() > fromFloatType.getWidth())
- return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
- if (toFloatType.getWidth() < fromFloatType.getWidth())
- return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
- }
- }
-
- emitWarning(operand.getLoc()) << "could not cast operand of type "
- << operand.getType() << " to " << toType;
- return operand;
+ // The scalar cast rules now live in mlir::convertScalarToDtype (Arith utils),
+ // which elementwise fusion also uses; it warns at `loc` and returns `operand`
+ // unchanged when no cast applies.
+ return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6a9c4e36a07e1..e639158cb7229 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1744,8 +1744,14 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
if (!fillOp)
continue;
fillFound = true;
+ Value fillVal = fillOp.value();
+ auto resultType =
+ fillOp.result().getType().cast<RankedTensorType>().getElementType();
+ Value convertedVal =
+ convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
+ /*isUnsignedCast =*/false);
payload.getArgument(opOperand->getOperandNumber())
- .replaceAllUsesWith(fillOp.value());
+ .replaceAllUsesWith(convertedVal);
}
return success(fillFound);
}
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index ca142e3fe0ad2..0de109ea83022 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1017,6 +1017,30 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// -----
+// The fill value is f32 while the filled tensor is f16: folding the fill into
+// the consumer generic must cast the scalar to f16 instead of substituting the
+// f32 value directly into the f16 payload.
+// CHECK-LABEL: func @fold_fill_generic_different_dtype
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
+  %1 = tensor.empty(%0) : tensor<?xf16>
+  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
+  %3 = tensor.empty(%0) : tensor<?xf16>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
+  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
+    %5 = arith.addf %arg1, %arg2 : f16
+    linalg.yield %5 : f16
+  } -> tensor<?xf16>
+  return %4 : tensor<?xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
More information about the Mlir-commits
mailing list