-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path2-planes-flow-mesh.py
More file actions
executable file
·148 lines (120 loc) · 4.87 KB
/
2-planes-flow-mesh.py
File metadata and controls
executable file
·148 lines (120 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python
# takes a pair of slices and aligns them
# a derivative of https://github.com/google-research/sofima/blob/main/notebooks/em_alignment.ipynb
# this first part just does the GPU intensive stuff
# bsub -Phess -n4 -gpu "num=1" -q gpu_h100 -Is /bin/zsh
# conda run -n multi-sem --no-capture-output python ./2-planes-flow-mesh.py data-hess-2-planes /nrs/hess/data/hess_wafers_60_61/export/zarr_datasets/surface-align/run_20251219_110000/pass03-scale2 flat-w61_serial_080_to_089-w61_s080_r00-top-face.zarr flat-w61_serial_070_to_079-w61_s079_r00-bot-face.zarr 160 8 1024
import os
import argparse
import time
import importlib
import numpy as np
from sofima import flow_field
from sofima import flow_utils
from sofima import mesh
# Parse command line arguments.
parser = argparse.ArgumentParser(
    description="Takes a pair of slices and aligns them - GPU intensive processing"
)
parser.add_argument(
    "data_loader",
    help="Data loader module name, e.g., data-test-2-planes"
)
parser.add_argument(
    "basepath",
    help="filepath to stitched planes"
)
parser.add_argument(
    "top",
    help="filename of top of one slab"
)
parser.add_argument(
    "bot",
    help="filename of bottom of an adjacent slab"
)
# patch_size and stride stay strings: the raw text is embedded in the output
# name when results are saved, and a comma-separated list selects several
# coarse-to-fine passes (parsed into integer lists below).
parser.add_argument(
    "patch_size",
    type=str,
    help="Side length of (square) patch for processing (in pixels, e.g., 32); "
         "a comma-separated list (e.g., 160,80) runs one pass per entry",
)
parser.add_argument(
    "stride",
    type=str,
    help="Distance of adjacent patches (in pixels, e.g., 8); "
         "comma-separated, with as many entries as patch_size",
)
parser.add_argument(
    "batch_size",
    type=int,
    help="Batch size for processing"
)
args = parser.parse_args()
data_loader = args.data_loader
basepath = args.basepath
top = args.top
bot = args.bot
patch_size = args.patch_size
stride = args.stride
batch_size = args.batch_size
print("data_loader =", data_loader)
print("basepath =", basepath)
print("top =", top)
print("bot =", bot)
print("patch_size =", patch_size)
print("stride =", stride)
print("batch_size =", batch_size)


def _positive_int_list(name, text):
    """Parse a comma-separated list of positive integers such as "160,80".

    Malformed input is reported through ``parser.error``, which prints the
    usage message to stderr and exits with a non-zero status. Without this
    check, bad input would surface as an uncaught traceback (or, for a zero
    stride, as a ZeroDivisionError much later).
    """
    try:
        values = [int(x) for x in text.split(',')]
    except ValueError:
        parser.error(f"{name} must be a comma-separated list of integers, got {text!r}")
    if any(v <= 0 for v in values):
        parser.error(f"{name} entries must be positive, got {text!r}")
    return values


patch_size_int = _positive_int_list("patch_size", patch_size)
stride_int = _positive_int_list("stride", stride)
if len(patch_size_int) != len(stride_int):
    # parser.error exits with a failure status; a bare exit() would report success.
    parser.error("lengths of patch_size and stride must be equal")
if batch_size <= 0:
    parser.error(f"batch_size must be positive, got {batch_size}")

# Import the loader module by name. Only the basename is used, so the module
# must be importable from sys.path (e.g. sit next to this script). A trailing
# ".py", as produced by shell tab completion, is tolerated.
loader_name = os.path.basename(data_loader)
if loader_name.endswith(".py"):
    loader_name = loader_name[:-3]
data = importlib.import_module(loader_name)
# Presumably the two images handed to flow_field below -- confirm against the
# loader module's load_data.
ttop, tbot = data.load_data(basepath, top, bot)
# Calculate the flow fields (uses the GPU).
# Each (patch_size, stride) pair is one pass. Every pass after the first is
# seeded with the XY flow of the previous pass via pre_targeting_field; that
# field's grid spacing is the previous pass's stride.


def _mean_abs_xy_flow(f):
    """Return the mean absolute value of the finite XY-flow entries of *f*.

    *f* is channels-first. Only channels 0 and 1 (the XY flow components)
    are averaged; channels 2 and 3 hold estimation-quality measures and would
    skew the statistic. Returns NaN, without an empty-slice warning, when no
    XY entry is finite.
    """
    xy = np.asarray(f)[:2]
    finite = xy[np.isfinite(xy)]
    return np.abs(finite).mean() if finite.size else np.nan


mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
flow = None
prev_stride = None
for size, step in zip(patch_size_int, stride_int):
    refine = {}
    if flow is not None:
        # Guide this pass with the previous pass's XY flow channels.
        refine = {
            "pre_targeting_field": flow[:2],
            "pre_targeting_step": (prev_stride, prev_stride),
        }
    t0 = time.time()
    flow = mfc.flow_field(ttop, tbot, (size, size), (step, step),
                          batch_size=batch_size, **refine)
    print("mean of flows = ", _mean_abs_xy_flow(flow))
    print("flow_field took", time.time() - t0, "sec")
    prev_stride = step
# The first two channels store the XY components of the flow vector, and the
# two remaining channels are measures of estimation quality (see
# sofima.flow_field._batched_peaks for more info).
# Insert a z axis after the channel axis to get [channels, z, y, x].
flow = np.array(flow)[:, np.newaxis]
# Pad to account for the edges of the images where there is insufficient
# context to estimate flow.
pad = patch_size_int[-1] // 2 // stride_int[-1]
flow = np.pad(flow, [[0, 0], [0, 0], [pad, pad], [pad, pad]], constant_values=np.nan)
# Remove uncertain flow estimates by replacing them with NaNs.
t0 = time.time()
flow_clean = flow_utils.clean_flow(flow, min_peak_ratio=1.6, min_peak_sharpness=1.6,
                                   max_magnitude=80, max_deviation=20)
print("clean_flow took", time.time() - t0, "sec")
# NOTE: multi-resolution flow fields would be merged here (not implemented).

# Relax a mesh against the cleaned flow field (also uses the GPU): find a
# configuration of the imagery that is compatible with the estimated flow
# while preserving the original geometry as much as possible.
config = mesh.IntegrationConfig(
    dt=0.001,
    gamma=0.0,
    k0=0.01,
    k=0.1,
    stride=(stride_int[-1], stride_int[-1]),
    num_iters=1000,
    max_iters=100000,
    stop_v_max=0.005,
    dt_max=1000,
    start_cap=0.01,
    final_cap=10,
    prefer_orig_order=True,
)
# Start the relaxation from an all-zero array shaped like the cleaned flow.
initial_mesh = np.zeros_like(flow_clean)
t0 = time.time()
solved, _e_kin, _num_steps = mesh.relax_mesh(initial_mesh, flow_clean, config)
print("relax_mesh took", time.time() - t0, "sec")

# Output name built from the raw patch_size/stride strings and the stem of
# the top filename.
params = f"patch{patch_size}.stride{stride}.top{os.path.splitext(top)[0]}"
data.save_flow(flow_clean, basepath, params)
data.save_mesh(solved, basepath, params)