[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (PR #142797)
llvmlistbot at llvm.org
Fri Jun 6 10:25:06 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
Changes:
This PR adds support for lowering elementwise operations (unary and binary) from workgroup to subgroup distribution.
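To illustrate, here is a condensed before/after sketch based on the test file added in this patch (the subgroup-level SSA names are hypothetical; the pass produces one subgroup-sized op per `sg_layout` tile):

```mlir
// Before: a workgroup-level op carrying the full #xegpu.layout attribute.
%addf = arith.addf %load_a, %load_b
  {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8],
                          lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<24x32xf32>

// After: each subgroup operates on its own 12x8 tile; sg_layout/sg_data
// are dropped and the remaining lane layout is attached as layout_result_0.
%addf_sg = arith.addf %a_sg, %b_sg
  {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<12x8xf32>
```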
---
Patch is 87.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142797.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+163)
- (added) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir (+1048)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..e1687031d259a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,15 +8,18 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
namespace mlir {
namespace xegpu {
@@ -314,6 +317,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+// This pattern transforms elementwise ops (unary/binary) in math/arith dialect
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+ using OpConversionPattern<Op>::OpConversionPattern;
+ using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // All operands/results must be 1D or 2D vectors
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
+ return rewriter.notifyMatchFailure(
+ op, "Result type is not a 1D or 2D vector");
+
+ ArrayRef<int64_t> shape = resultType.getShape();
+ for (Value operand : op->getOperands()) {
+ auto operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != resultType.getRank() ||
+ operandType.getShape() != shape) {
+ return rewriter.notifyMatchFailure(
+ op, "Operand type is not a 1D or 2D vector with the same shape as "
+ "result type");
+ }
+ }
+
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Operation does not have a valid layout attribute for subgroup "
+ "distribution");
+
+ // Extract sgShape from layout
+ SmallVector<int64_t> sgShape;
+ if (auto sgDataAttr = layout.getSgData()) {
+ sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+ } else {
+ auto sgLayoutArr = layout.getSgLayout();
+ sgShape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i) {
+ assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+ sgShape.push_back(shape[i] / sgLayoutArr[i]);
+ }
+ }
+
+ size_t numVariants = adaptor.getOperands().empty()
+ ? 0
+ : adaptor.getOperands().front().size();
+ for (auto &operandVec : adaptor.getOperands())
+ if (operandVec.size() != numVariants)
+ return rewriter.notifyMatchFailure(
+ op, "Operand lists have mismatched sizes");
+
+ SmallVector<Value> newResults;
+
+ auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ VectorType newResultType =
+ origResultType
+ ? VectorType::get(sgShape, origResultType.getElementType())
+ : VectorType::get(sgShape, resultType.getElementType());
+
+ for (size_t i = 0; i < numVariants; ++i) {
+ SmallVector<Value> operands;
+ for (auto &operandVec : adaptor.getOperands())
+ operands.push_back(operandVec[i]);
+
+ auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+ // Copy all attributes except "layout", and add "layout_result_0" with
+ // sgLayout/data dropped
+ for (auto attr : op->getAttrs()) {
+ if (attr.getName() != "layout")
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
+ newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
+
+ newResults.push_back(newOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newResults});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -322,6 +409,57 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
patterns.getContext());
+ // Add elementwise operations that can be distributed to subgroups
+ patterns.add<
+ WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
+ WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
+ WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
+ WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
+ WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
+ WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
+ WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
+ WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
+ WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
+ WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
+ WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
+ WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
+ WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
+ WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
+ WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
+ WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
+ WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
+ WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
+ WgToSgElementwiseOp<arith::MaximumFOp>,
+ WgToSgElementwiseOp<arith::MinimumFOp>,
+ WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
+ WgToSgElementwiseOp<arith::TruncFOp>,
+ WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
+ WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
+ WgToSgElementwiseOp<arith::SIToFPOp>,
+ WgToSgElementwiseOp<arith::UIToFPOp>,
+ WgToSgElementwiseOp<arith::FPToSIOp>,
+ WgToSgElementwiseOp<arith::FPToUIOp>,
+ WgToSgElementwiseOp<arith::IndexCastUIOp>,
+ WgToSgElementwiseOp<arith::IndexCastOp>,
+ WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
+ WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
+ WgToSgElementwiseOp<arith::CeilDivSIOp>,
+ WgToSgElementwiseOp<arith::CeilDivUIOp>,
+ WgToSgElementwiseOp<arith::FloorDivSIOp>,
+ WgToSgElementwiseOp<arith::MaxNumFOp>,
+ WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
+ WgToSgElementwiseOp<arith::MinNumFOp>,
+ WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
+ WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
+ WgToSgElementwiseOp<arith::SelectOp>, WgToSgElementwiseOp<arith::XOrIOp>,
+ WgToSgElementwiseOp<math::AbsIOp>, WgToSgElementwiseOp<math::CbrtOp>,
+ WgToSgElementwiseOp<math::CopySignOp>, WgToSgElementwiseOp<math::CtPopOp>,
+ WgToSgElementwiseOp<math::ErfcOp>, WgToSgElementwiseOp<math::Exp2Op>,
+ WgToSgElementwiseOp<math::ExpM1Op>, WgToSgElementwiseOp<math::FPowIOp>,
+ WgToSgElementwiseOp<math::IPowIOp>, WgToSgElementwiseOp<math::Log10Op>,
+ WgToSgElementwiseOp<math::Log1pOp>, WgToSgElementwiseOp<math::RoundOp>,
+ WgToSgElementwiseOp<math::RoundEvenOp>,
+ WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -368,6 +506,31 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
return isLegal(layout);
});
+ target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+ [=](Operation *op) -> std::optional<bool> {
+ // Handle unary and binary operations
+ if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+ return true;
+
+ // check if input and output are vectors
+ VectorType resultType =
+ dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType || resultType.getRank() != 2)
+ return true;
+
+ // Check if all operands are vectors
+ for (Value operand : op->getOperands()) {
+ VectorType operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != 2 ||
+ operandType.getShape() != resultType.getShape()) {
+ return true;
+ }
+ }
+
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
+ op->getAttrOfType<xegpu::LayoutAttr>("layout"));
+ return isLegal(layout);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..85767f4f2bd67
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,1048 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @elementwise_ops {
+ // CHECK-LABEL: elemwise_ops
+ gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // Floating point ops
+ // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.absf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.erf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.floor {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.ceil {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ %addf = arith.addf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %subf = arith.subf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %exp = math.exp %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sqrt = math.sqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %absf = math.absf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cos = math.cos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cosh = math.cosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acos = math.acos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acosh = math.acosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sin = math.sin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sinh = math.sinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asin = math.asin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asinh = math.asinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tan = math.tan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tanh = math.tanh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan = math.atan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan2 = math.atan2 %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atanh = math.a...
[truncated]
``````````
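As a usage note, the test drives the rewrite with `mlir-opt --xegpu-wg-to-sg-distribute`. The conversion pattern bails out on ops that do not carry a `layout` attribute with an `sg_layout`, and the dynamic legality callback treats such ops as already legal, so they pass through unchanged. A minimal sketch (operand names are hypothetical):

```mlir
// No #xegpu.layout attribute with sg_layout: this op is already legal
// for the pass and is left untouched by the distribution.
%sum = arith.addf %x, %y : vector<24x32xf32>
```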
https://github.com/llvm/llvm-project/pull/142797
More information about the Mlir-commits mailing list