From 86db42e97f6f20330b1a54653eeff6814162c39b Mon Sep 17 00:00:00 2001
From: Anav Prasad
Date: Thu, 23 Apr 2026 02:28:56 +0000
Subject: [PATCH] CUDA: fuse relu + sqr (#22249)

---
 ggml/src/ggml-cuda/ggml-cuda.cu | 30 +++++++++++++++++++++++++
 ggml/src/ggml-cuda/unary.cu     | 23 +++++++++++++++++++
 ggml/src/ggml-cuda/unary.cuh    |  2 ++
 tests/test-backend-ops.cpp      | 40 +++++++++++++++++++++++++++++++++
 4 files changed, 95 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 185956317..1c2c3b4ac 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3592,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         return true;
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx];
+        const ggml_tensor * sqr   = cgraph->nodes[node_idx+1];
+
+        if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) {
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+            return false;
+        }
+
+        if (unary->type != sqr->type) {
+            return false;
+        }
+
+        if (!ggml_is_contiguous(unary->src[0])) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
         && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -4100,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                     continue;
                 }
 
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) {
+                    ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]);
+                    i++;
+                    continue;
+                }
+
                 if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                     i += 2;
                     ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 4ad30fa1f..2aeba26f4 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) {
     return x * x;
 }
 
+static __device__ __forceinline__ float op_relu_sqr(float x) {
+    const float r = fmaxf(x, 0.0f);
+    return r * r;
+}
+
 static __device__ __forceinline__ float op_sqrt(float x) {
     return sqrtf(x);
 }
@@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary
         GGML_ABORT("Unsupported unary op for fused unary+mul");
     }
 }
+
+/* fused relu + sqr */
+
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) {
+    const ggml_tensor * src = relu_node->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src));
+    GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+    GGML_ASSERT(src->type == sqr_node->type);
+
+    const int k = ggml_nelements(src);
+    if (src->type == GGML_TYPE_F16) {
+        unary_cuda<op_relu_sqr>((const half *)src->data, (half *)sqr_node->data, k, stream);
+    } else {
+        unary_cuda<op_relu_sqr>((const float *)src->data, (float *)sqr_node->data, k, stream);
+    }
+}
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index f1dd2183a..81ed873ec 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
 
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node);
+
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 828a9c14a..716011316 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3522,6 +3522,40 @@ struct test_add_rms_norm : public test_case {
     }
 };
 
+// GGML_OP_UNARY(RELU) + GGML_OP_SQR (fused operation)
+struct test_relu_sqr : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "RELU_SQR";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_relu_sqr(ggml_type type = GGML_TYPE_F32,
+                  std::array<int64_t, 4> ne = {128, 2, 2, 2})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * r = ggml_relu(ctx, a);
+        ggml_set_name(r, "relu");
+
+        ggml_tensor * out = ggml_sqr(ctx, r);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
@@ -7311,6 +7345,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // fused relu + sqr (squared ReLU)
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        test_cases.emplace_back(new test_relu_sqr(type, { 128, 2, 2, 2 }));
+        test_cases.emplace_back(new test_relu_sqr(type, { 5, 7, 11, 13 }));
+    }
+
    // glu ops
    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        for (int v : {0, 1}) {
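Note: in the patch above, the fused path dispatches through ggml's templated unary launcher with the new op_relu_sqr device function, so one kernel computes dst[i] = max(x[i], 0)^2 instead of a RELU kernel followed by a SQR kernel. For reference only (not part of the patch), the following is a minimal standalone float32 CUDA sketch of the same elementwise computation; the kernel and launcher names are hypothetical, not ggml symbols, and the block size is an arbitrary choice.

#include <cuda_runtime.h>

// Illustrative squared-ReLU kernel: dst[i] = max(x[i], 0)^2 for i in [0, k).
static __global__ void relu_sqr_f32_kernel(const float * x, float * dst, const int k) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    const float r = fmaxf(x[i], 0.0f); // relu
    dst[i] = r * r;                    // sqr
}

// Illustrative launcher: one thread per element on the given stream.
static void relu_sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (k + block_size - 1) / block_size;
    relu_sqr_f32_kernel<<<num_blocks, block_size, 0, stream>>>(x, dst, k);
}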