[Mlir-commits] [mlir] 5bff523 - [mlir][arith] Add expansion pattern for ext/trunc of bf16
Robert Suderman
llvmlistbot at llvm.org
Wed Mar 29 17:59:08 PDT 2023
Author: Robert Suderman
Date: 2023-03-30T00:51:06Z
New Revision: 5bff523793ee8c30c260cc77b23c61dcbb606486
URL: https://github.com/llvm/llvm-project/commit/5bff523793ee8c30c260cc77b23c61dcbb606486
DIFF: https://github.com/llvm/llvm-project/commit/5bff523793ee8c30c260cc77b23c61dcbb606486.diff
LOG: [mlir][arith] Add expansion pattern for ext/trunc of bf16
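Conversion between bf16 and f32 is trivial at the bit level: the patterns
treat a bf16 value as the upper 16 bits of an f32. Extension and truncation
can therefore be described with elementary arith operations (bitcasts,
integer extension/truncation, and 16-bit shifts). Add expansion patterns
that perform these conversions efficiently.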
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D147091
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
mlir/test/Dialect/Arith/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 257a62aa39f78..6d60f8aefd63c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -38,6 +38,9 @@ void populateArithWideIntEmulationPatterns(
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
+/// Add patterns to expand Arith bf16 extf/truncf ops into lower-level bitcasts/shifts.
+void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8f34531937c5c..7a6246991cff3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -25,15 +26,13 @@ using namespace mlir;
/// Create an integer or index constant.
static Value createConst(Location loc, Type type, int value,
PatternRewriter &rewriter) {
-
- auto elTy = getElementTypeOrSelf(type);
- auto constantAttr = rewriter.getIntegerAttr(elTy, value);
-
- if (auto vecTy = llvm::dyn_cast<ShapedType>(type))
+ auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); // Integer attribute typed with the element type of `type` (or `type` itself).
+ if (auto shapedTy = dyn_cast<ShapedType>(type)) { // Shaped types get a dense-elements constant built from `attr`.
return rewriter.create<arith::ConstantOp>(
- loc, vecTy, DenseElementsAttr::get(vecTy, constantAttr));
+ loc, DenseElementsAttr::get(shapedTy, attr));
+ }
- return rewriter.create<arith::ConstantOp>(loc, constantAttr);
+ return rewriter.create<arith::ConstantOp>(loc, attr); // Non-shaped case: constant built directly from `attr`.
}
namespace {
@@ -187,6 +186,73 @@ struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
}
};
+struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { // Rewrites arith.extf bf16 -> f32 as bitcast + extui + shli + bitcast.
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::ExtFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Builder constructed with op's location.
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isBF16() || !resultETy.isF32()) { // Only the bf16 -> f32 element-type pair is handled.
+ return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
+ }
+
+ Type i16Ty = b.getI16Type();
+ Type i32Ty = b.getI32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { // For shaped operands, derive the integer types from the operand's shaped type.
+ i16Ty = shapedTy.clone(i16Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ }
+
+ Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand); // Reinterpret the bf16 bits as i16.
+ Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); // Unsigned-extend to i32.
+
+ Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); // Shift amount of 16.
+ Value shl = b.create<arith::ShLIOp>(exti, c16); // Move the 16 source bits into the upper half of the i32.
+ Value result = b.create<arith::BitcastOp>(resultTy, shl); // Reinterpret the i32 as the f32 result type.
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { // Rewrites arith.truncf f32 -> bf16 as bitcast + shrui + trunci + bitcast.
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(arith::TruncFOp op,
+ PatternRewriter &rewriter) const final {
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Builder constructed with op's location.
+ auto operand = op.getOperand();
+ Type operandTy = operand.getType();
+ Type resultTy = op.getType();
+ Type operandETy = getElementTypeOrSelf(operandTy);
+ Type resultETy = getElementTypeOrSelf(resultTy);
+
+ if (!operandETy.isF32() || !resultETy.isBF16()) { // Only the f32 -> bf16 element-type pair is handled.
+ return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
+ }
+
+ Type i16Ty = b.getI16Type();
+ Type i32Ty = b.getI32Type();
+ if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { // For shaped operands, derive the integer types from the operand's shaped type.
+ i16Ty = shapedTy.clone(i16Ty);
+ i32Ty = shapedTy.clone(i32Ty);
+ }
+
+ Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand); // Reinterpret the f32 bits as i32.
+ Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); // Shift amount: the low 16 bits are dropped, with no rounding step.
+ Value shl = b.create<arith::ShRUIOp>(bitcast, c16); // NOTE: named `shl`, but this is an unsigned right shift (ShRUIOp).
+ Value trunc = b.create<arith::TruncIOp>(i16Ty, shl); // NOTE(review): a NaN whose payload is only in the low 16 bits loses it here; confirm intended.
+ Value result = b.create<arith::BitcastOp>(resultTy, trunc); // Reinterpret the remaining 16 bits as the bf16 result type.
+
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
void runOnOperation() override {
@@ -204,6 +270,21 @@ struct ArithExpandOpsPass
arith::MaxFOp,
arith::MinFOp
>();
+
+ target.addDynamicallyLegalOp<arith::ExtFOp>(
+ [](arith::ExtFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ return !(inETy.isBF16() && outETy.isF32());
+ });
+
+ target.addDynamicallyLegalOp<arith::TruncFOp>(
+ [](arith::TruncFOp op) {
+ Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+ Type outETy = getElementTypeOrSelf(op.getType());
+ return !(inETy.isF32() && outETy.isBF16());
+ });
+
// clang-format on
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -220,12 +301,19 @@ void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
patterns.getContext());
}
+void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) { // Adds only the two bf16 <-> f32 ext/trunc expansion patterns.
+ patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
+ patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
// clang-format off
patterns.add<
MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
- MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>
+ MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
+ BFloat16ExtFOpConverter, // bf16 -> f32 arith.extf expansion.
+ BFloat16TruncFOpConverter // f32 -> bf16 arith.truncf expansion.
>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 7b7eb4003956a..ba87e2907abb1 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -215,3 +215,67 @@ func.func @minf(%a: f32, %b: f32) -> f32 {
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+func.func @extf_bf16(%arg0 : bf16) -> f32 {
+ %0 = arith.extf %arg0 : bf16 to f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: @extf_bf16
+// CHECK-SAME: %[[ARG0:.+]]: bf16
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : bf16 to i16
+// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : i16 to i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16
+// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
+// CHECK: return %[[BITCAST]]
+
+// -----
+
+func.func @extf_vector_bf16(%arg0 : vector<4xbf16>) -> vector<4xf32> {
+ %0 = arith.extf %arg0 : vector<4xbf16> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extf_vector_bf16
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xbf16>
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xbf16> to vector<4xi16>
+// CHECK-DAG: %[[EXT:.+]] = arith.extui %[[BITCAST]] : vector<4xi16> to vector<4xi32>
+// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16>
+// CHECK-DAG: %[[SHLI:.+]] = arith.shli %[[EXT]], %[[C16]]
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[SHLI]] : vector<4xi32> to vector<4xf32>
+// CHECK: return %[[BITCAST]]
+
+// -----
+
+func.func @truncf_f32(%arg0 : f32) -> bf16 {
+ %0 = arith.truncf %arg0 : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncf_f32
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : f32 to i32
+// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
+// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i16
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[TRUNC]] : i16 to bf16
+// CHECK: return %[[BITCAST]] : bf16
+
+// -----
+
+func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
+ %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
+ return %0 : vector<4xbf16>
+}
+
+// CHECK-LABEL: @truncf_vector_f32
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+// CHECK-DAG: %[[C16:.+]] = arith.constant dense<16>
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[ARG0]] : vector<4xf32> to vector<4xi32>
+// CHECK-DAG: %[[SHR:.+]] = arith.shrui %[[BITCAST]], %[[C16]]
+// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : vector<4xi32> to vector<4xi16>
+// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %[[TRUNC]] : vector<4xi16> to vector<4xbf16>
+// CHECK: return %[[BITCAST]] : vector<4xbf16>
More information about the Mlir-commits
mailing list