Skip to content

Commit 7e7875d

Browse files
authored
Add Gaussian cdf as atom (#50)
* normal cdf * broadcast bug fix * run formatter * update version
1 parent 3fe6f5b commit 7e7875d

File tree

6 files changed

+101
-2
lines changed

6 files changed

+101
-2
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
55

66
set(DIFF_ENGINE_VERSION_MAJOR 0)
77
set(DIFF_ENGINE_VERSION_MINOR 1)
8-
set(DIFF_ENGINE_VERSION_PATCH 4)
8+
set(DIFF_ENGINE_VERSION_PATCH 5)
99
set(DIFF_ENGINE_VERSION "${DIFF_ENGINE_VERSION_MAJOR}.${DIFF_ENGINE_VERSION_MINOR}.${DIFF_ENGINE_VERSION_PATCH}")
1010
add_compile_definitions(DIFF_ENGINE_VERSION="${DIFF_ENGINE_VERSION}")
1111

include/elementwise_univariate.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ expr *new_atanh(expr *child);
3737
expr *new_logistic(expr *child);
3838
expr *new_power(expr *child, double p);
3939
expr *new_xexp(expr *child);
40+
expr *new_normal_cdf(expr *child);
4041

4142
/* the jacobian and wsum_hess for elementwise univariate atoms are always
4243
initialized in the same way and implement the chain rule in the same way */

src/affine/broadcast.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ static void jacobian_init(expr *node)
129129
int offset = 0;
130130
for (int i = 0; i < node->d2; i++)
131131
{
132-
int nnz_in_row = Jx->p[i + 1] - Jx->p[i];
133132
for (int j = 0; j < node->d1; j++)
134133
{
134+
int nnz_in_row = Jx->p[j + 1] - Jx->p[j];
135135
J->p[i * node->d1 + j] = offset;
136136
offset += nnz_in_row;
137137
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2026 Daniel Cederberg and William Zhang
3+
*
4+
* This file is part of the DNLP-differentiation-engine project.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
#include "elementwise_univariate.h"
19+
#include <math.h>
20+
21+
#ifndef M_PI
22+
#define M_PI 3.14159265358979323846
23+
#endif
24+
25+
#ifndef M_SQRT2
26+
#define M_SQRT2 1.41421356237309504880
27+
#endif
28+
29+
static const double INV_SQRT_2PI = 0.3989422804014326779399461;
30+
31+
static void forward(expr *node, const double *u)
32+
{
33+
node->left->forward(node->left, u);
34+
35+
double *x = node->left->value;
36+
for (int i = 0; i < node->size; i++)
37+
{
38+
node->value[i] = 0.5 * (1.0 + erf(x[i] / M_SQRT2));
39+
}
40+
}
41+
42+
static void local_jacobian(expr *node, double *vals)
43+
{
44+
double *x = node->left->value;
45+
for (int j = 0; j < node->size; j++)
46+
{
47+
vals[j] = INV_SQRT_2PI * exp(-0.5 * x[j] * x[j]);
48+
}
49+
}
50+
51+
static void local_wsum_hess(expr *node, double *out, const double *w)
52+
{
53+
double *x = node->left->value;
54+
for (int j = 0; j < node->size; j++)
55+
{
56+
/* could avoid recomputing this (like in logistic) */
57+
double phi = INV_SQRT_2PI * exp(-0.5 * x[j] * x[j]);
58+
out[j] = w[j] * (-x[j] * phi);
59+
}
60+
}
61+
62+
expr *new_normal_cdf(expr *child)
63+
{
64+
expr *node = new_elementwise(child);
65+
node->forward = forward;
66+
node->local_jacobian = local_jacobian;
67+
node->local_wsum_hess = local_wsum_hess;
68+
69+
return node;
70+
}

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "forward_pass/composite/test_composite.h"
1616
#include "forward_pass/elementwise/test_exp.h"
1717
#include "forward_pass/elementwise/test_log.h"
18+
#include "forward_pass/elementwise/test_normal_cdf.h"
1819
#include "forward_pass/test_left_matmul_dense.h"
1920
#include "forward_pass/test_matmul.h"
2021
#include "forward_pass/test_prod_axis_one.h"
@@ -101,6 +102,7 @@ int main(void)
101102
mu_run_test(test_promote_scalar_to_vector, tests_run);
102103
mu_run_test(test_exp, tests_run);
103104
mu_run_test(test_log, tests_run);
105+
mu_run_test(test_normal_cdf, tests_run);
104106
mu_run_test(test_composite, tests_run);
105107
mu_run_test(test_sum_axis_neg1, tests_run);
106108
mu_run_test(test_sum_axis_0, tests_run);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <math.h>
2+
#include <stdio.h>
3+
#include <stdlib.h>
4+
5+
#include "affine.h"
6+
#include "elementwise_univariate.h"
7+
#include "expr.h"
8+
#include "minunit.h"
9+
#include "test_helpers.h"
10+
11+
#ifndef M_SQRT2
12+
#define M_SQRT2 1.41421356237309504880
13+
#endif
14+
15+
const char *test_normal_cdf(void)
16+
{
17+
double u[3] = {1.0, 2.0, 3.0};
18+
expr *var = new_variable(3, 1, 0, 3);
19+
expr *node = new_normal_cdf(var);
20+
node->forward(node, u);
21+
/* computed in python */
22+
double correct[3] = {0.8413447460685429, 0.9772498680518208, 0.9986501019683699};
23+
mu_assert("fail", cmp_double_array(node->value, correct, 3));
24+
free_expr(node);
25+
return 0;
26+
}

0 commit comments

Comments
 (0)