Skip to content

Commit d22fd27

Browse files
committed
Unit testing
1 parent 3e26a62 commit d22fd27

File tree

5 files changed

+357
-98
lines changed

5 files changed

+357
-98
lines changed

examples/Scratch/global_test.py

Lines changed: 122 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,6 @@
2828
| / | /
2929
4--------------5
3030
31-
* | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7
32-
----------------------------------
33-
0 | | | | | | | |
34-
----------------------------------
35-
1 | | | | | | | |
36-
----------------------------------
37-
2 | | | | | | | |
38-
----------------------------------
39-
3 | | | | | | | |
40-
----------------------------------
41-
4 | | | | | | | |
42-
----------------------------------
43-
5 | | | | | | | |
44-
----------------------------------
45-
6 | | | | | | | |
46-
----------------------------------
47-
7 | | | | | | | |
48-
4931
A two-element mesh:
5032
5133
Element 1 Element 2
@@ -59,33 +41,6 @@
5941
| / | / | /
6042
6--------------7--------------8
6143
62-
63-
* | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11
64-
---------------------------------------------------
65-
0 |e1 |e1 | | | | | | | | | |
66-
---------------------------------------------------
67-
1 |e1 |e12| | | | | | | | | |
68-
---------------------------------------------------
69-
2 | | |e2 | | | | | | | | |
70-
---------------------------------------------------
71-
3 | | | |e1 | | | | | | | |
72-
---------------------------------------------------
73-
4 | | | | |e12| | | | | | |
74-
---------------------------------------------------
75-
5 | | | | | |e2 | | | | | |
76-
---------------------------------------------------
77-
6 | | | | | | |e1 | | | | |
78-
---------------------------------------------------
79-
7 | | | | | | | |e12| | | |
80-
---------------------------------------------------
81-
8 | | | | | | | | |e2 | | |
82-
---------------------------------------------------
83-
9 | | | | | | | | | |e1 | |
84-
---------------------------------------------------
85-
10| | | | | | | | | | |e12 |
86-
---------------------------------------------------
87-
11| | | | | | | | | | | |e2
88-
8944
"""
9045

9146
# Standard imports
@@ -94,15 +49,15 @@
9449

9550
# Local imports
9651
from SPI2py.models.physics.distributed.mesh import generate_mesh
97-
from SPI2py.models.physics.distributed.assembly import assemble_global_stiffness
98-
52+
from SPI2py.models.physics.distributed.assembly import assemble_sparse_global_stiffness
9953

10054
w = 1 # x
10155
h = 1 # y
10256
d = 1 # Z
10357

58+
# Assemble the global stiffness matrix and convert it to a dense matrix
10459
nodes_2e, elements_2e, _, _, _, _, _, _, _ = generate_mesh(0, 2*w, 0, h, 0, d, element_size=1.0)
105-
Ke_flat, elem_indices, rows_flat, cols_flat, n_nodes, n_elem = assemble_global_stiffness(nodes_2e, elements_2e, base_k=1.0)
60+
Ke_flat, elem_indices, rows_flat, cols_flat, n_nodes, n_elem = assemble_sparse_global_stiffness(nodes_2e, elements_2e, base_k=1.0)
10661
indices = jnp.stack([rows_flat, cols_flat], axis=-1)
10762
K_global = K_pf = BCOO((Ke_flat, indices), shape=(n_nodes, n_nodes))
10863
K_global_dense = K_global.todense()
@@ -133,9 +88,6 @@
13388
[H, G, F, E, D, C, B, A]])
13489

13590

136-
# K_global_expected = jnp.zeros((12, 12))
137-
# K_global_expected[]
138-
13991

14092
# Confirm that all entries that should be zeros are zeros
14193

@@ -190,3 +142,122 @@
190142
assert jnp.all(K_global_dense[non_zeros_i, non_zeros_j] != 0.0)
191143

192144

145+
146+
147+
# Now spot check some values; K_e1 == K_e2 == K_ex
148+
149+
# K[0,0] == K_e1[0,0]
150+
assert jnp.isclose(K_global_dense[0, 0], k_ex[0, 0])
151+
152+
# K[1,1] == K_e1[1,1] + K_e2[0,0]
153+
assert jnp.isclose(K_global_dense[1, 1], k_ex[1, 1] + k_ex[0, 0])
154+
155+
# K[0,9] == K_e1[0,7] and K[9,0] == K_e2[7,0]
156+
assert jnp.isclose(K_global_dense[0, 9], k_ex[0, 7])
157+
158+
# K[11,11] == K_e2[6,6]
159+
assert jnp.isclose(K_global_dense[11, 11], k_ex[6, 6])
160+
161+
162+
163+
164+
165+
166+
167+
168+
169+
170+
171+
172+
173+
# # Verify
174+
# cond_00_00 = 1
175+
# cond_00_01 = 1
176+
# cond_00_02 = 1
177+
# cond_00_03 = 1
178+
# cond_00_04 = 1
179+
# cond_00_05 = 1
180+
# cond_00_06 = 1
181+
# cond_00_07 = 1
182+
# cond_00_08 = 1
183+
# cond_00_09 = 1
184+
# cond_00_10 = 1
185+
# cond_00_11 = 1
186+
#
187+
# cond_01_00 = 1
188+
# cond_01_01 = 1
189+
# cond_01_02 = 1
190+
# cond_01_03 = 1
191+
# cond_01_04 = 1
192+
# cond_01_05 = 1
193+
# cond_01_06 = 1
194+
# cond_01_07 = 1
195+
# cond_01_08 = 1
196+
# cond_01_09 = 1
197+
# cond_01_10 = 1
198+
# cond_01_11 = 1
199+
#
200+
# cond_02_00 = 1
201+
# cond_02_01 = 1
202+
# cond_02_02 = 1
203+
# cond_02_03 = 1
204+
# cond_02_04 = 1
205+
# cond_02_05 = 1
206+
# cond_02_06 = 1
207+
# cond_02_07 = 1
208+
# cond_02_08 = 1
209+
# cond_02_09 = 1
210+
# cond_02_10 = 1
211+
# cond_02_11 = 1
212+
#
213+
# cond_03_00 = 1
214+
# cond_03_01 = 1
215+
# cond_03_02 = 1
216+
# cond_03_03 = 1
217+
# cond_03_04 = 1
218+
# cond_03_05 = 1
219+
# cond_03_06 = 1
220+
# cond_03_07 = 1
221+
# cond_03_08 = 1
222+
# cond_03_09 = 1
223+
# cond_03_10 = 1
224+
# cond_03_11 = 1
225+
#
226+
# cond_03_00 = 1
227+
# cond_03_01 = 1
228+
# cond_03_02 = 1
229+
# cond_03_03 = 1
230+
# cond_03_04 = 1
231+
# cond_03_05 = 1
232+
# cond_03_06 = 1
233+
# cond_03_07 = 1
234+
# cond_03_08 = 1
235+
# cond_03_09 = 1
236+
# cond_03_10 = 1
237+
# cond_03_11 = 1
238+
#
239+
# cond_04_00 = 1
240+
# cond_04_01 = 1
241+
# cond_04_02 = 1
242+
# cond_04_03 = 1
243+
# cond_04_04 = 1
244+
# cond_04_05 = 1
245+
# cond_04_06 = 1
246+
# cond_04_07 = 1
247+
# cond_04_08 = 1
248+
# cond_04_09 = 1
249+
# cond_04_10 = 1
250+
# cond_04_11 = 1
251+
#
252+
# cond_00_00 = 1
253+
# cond_00_01 = 1
254+
# cond_00_02 = 1
255+
# cond_00_03 = 1
256+
# cond_00_04 = 1
257+
# cond_00_05 = 1
258+
# cond_00_06 = 1
259+
# cond_00_07 = 1
260+
# cond_00_08 = 1
261+
# cond_00_09 = 1
262+
# cond_00_10 = 1
263+
# cond_00_11 = 1

examples/Scratch/scratch_test.py renamed to examples/Scratch/scratch_t.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
import jax.numpy as jnp
1+
# Standard imports
22
import matplotlib.pyplot as plt
3-
from SPI2py.models.physics.distributed.mesh import generate_mesh, find_active_nodes, find_face_nodes
4-
from SPI2py.models.physics.distributed.element import assemble_local_stiffness_matrix_scalar
3+
import jax.numpy as jnp
4+
from jax.experimental.sparse import BCOO
5+
6+
# Local imports
7+
from SPI2py.models.physics.distributed.mesh import generate_mesh, find_face_nodes
58
from SPI2py.models.physics.distributed.assembly import assemble_base_global_system_partition, apply_bc_partition
6-
from SPI2py.models.physics.distributed.quadrature import shape_functions
79
from SPI2py.models.physics.distributed.solver import solve_system_partition
8-
10+
from SPI2py.models.physics.distributed.assembly import assemble_sparse_global_stiffness
911

1012

1113
"""
@@ -26,19 +28,29 @@
2628

2729
element_size = 0.25
2830

31+
# Generate the mesh
2932
nodes, elements, centers, nx, ny, nz, lx, ly, lz = generate_mesh(x_min, x_max, y_min, y_max, z_min, z_max,
3033
element_size=element_size)
3134

35+
36+
3237
densities = jnp.ones(elements.shape[0]) # uniform density
3338
heat_loads = jnp.zeros(elements.shape[0]) # no heat generation
3439

3540
# Find active nodes and face nodes
36-
dirichlet_nodes = find_face_nodes(nodes, jnp.array([0.0, 0.0, -1.0]))
37-
robin_nodes = find_face_nodes(nodes, jnp.array([0.0, 0.0, 1.0]))
41+
top_normal = jnp.array([0, 1, 0])
42+
bottom_normal = jnp.array([0, -1, 0])
43+
44+
dirichlet_nodes = find_face_nodes(nodes, bottom_normal)
45+
robin_nodes = find_face_nodes(nodes, top_normal)
46+
3847
robin_area = element_size * element_size
3948

4049
dirichlet_temperature = T_fixed * jnp.ones(dirichlet_nodes.shape[0])
4150

51+
52+
53+
4254
# Initialize global stiffness matrix and force vector
4355
K_base, f_base, elem_indices = assemble_base_global_system_partition(nodes, elements,
4456
base_k=k,

src/SPI2py/models/physics/distributed/assembly.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
# @jit
10-
def assemble_global_stiffness(nodes, elements, base_k):
10+
def assemble_sparse_global_stiffness(nodes, elements, base_k):
1111
"""
1212
Assemble the global base stiffness matrix.
1313
Indices are included for mapping element densities to each of their local stiffness matrix contributions.
@@ -60,6 +60,10 @@ def assemble_global_stiffness(nodes, elements, base_k):
6060
return Ke_flat, elem_indices, rows_flat, cols_flat, n_nodes, n_elem
6161

6262

63+
def partition_sparse_global_system():
64+
pass
65+
66+
6367
# @jit
6468
def assemble_base_global_system_partition(nodes, elements,
6569
base_k,
@@ -70,8 +74,8 @@ def assemble_base_global_system_partition(nodes, elements,
7074
idx_p = d_nodes
7175
idx_f = jnp.setdiff1d(idx, idx_p)
7276

73-
Ke_flat, elem_indices_full, rows_flat, cols_flat, n_nodes, n_elem = assemble_global_stiffness(nodes, elements,
74-
base_k)
77+
Ke_flat, elem_indices_full, rows_flat, cols_flat, n_nodes, n_elem = assemble_sparse_global_stiffness(nodes, elements,
78+
base_k)
7579
indices = jnp.stack([rows_flat, cols_flat], axis=-1) # shape: (nnz, 2)
7680

7781
# Partition the contributions into four blocks based on free (idx_f) and prescribed (idx_p) DOF indices.
@@ -134,8 +138,8 @@ def remap_local(global_inds, idx_array):
134138

135139
# Apply boundary conditions (both Robin and Dirichlet).
136140
K_base, f_base = apply_bc_partition(K_base, f_base,
137-
r_nodes, r_h, r_T_inf, r_area,
138-
idx_f, idx_p)
141+
r_nodes, r_h, r_T_inf, r_area,
142+
idx_f, idx_p)
139143

140144
return K_base, f_base, elem_indices
141145

@@ -195,9 +199,9 @@ def update_global_stiffness_partition(K_base, f_base,
195199
element_densities, element_heat_loads):
196200

197201
# Unpack the partitioned stiffness matrix, load vector, and their indices.
198-
K_ff, K_fp, K_pf, K_pp = K_base
199-
f_f, f_p = f_base
200-
idx_f, idx_p = idx_vector
202+
K_ff, K_fp, K_pf, K_pp = K_base
203+
f_f, f_p = f_base
204+
idx_f, idx_p = idx_vector
201205
ei_ff, ei_fp, ei_pf, ei_pp = ei_matrix
202206

203207
# Flatten the pseudo-densities and heat loads to ensure they are 1D.
@@ -251,7 +255,7 @@ def assemble_base_global_system_penalty(nodes, elements, base_k,
251255
r_nodes, r_h, r_T_inf, r_area,
252256
d_nodes, d_T):
253257

254-
Ke_flat, elem_indices, rows_flat, cols_flat, n_nodes, n_elem = assemble_global_stiffness(nodes, elements, base_k)
258+
Ke_flat, elem_indices, rows_flat, cols_flat, n_nodes, n_elem = assemble_sparse_global_stiffness(nodes, elements, base_k)
255259

256260
# Stack rows and cols to form an index array of shape (n_elem*64, 2).
257261
indices = jnp.stack([rows_flat, cols_flat], axis=-1)

0 commit comments

Comments
 (0)