
Commit 948eab3

fix faulty polygon tests for grids with a 180° central meridian for the shtns backend

1 parent: 93c468d

4 files changed: +27 −10 lines changed

src/shxarray/geom/polygons.py

Lines changed: 14 additions & 4 deletions

@@ -41,6 +41,7 @@ def polygon2sh(polygeom,nmax:int=100,auxcoord=None,engine="shlib",**kwargs) ->xr
     if type(polygeom) != gpd.GeoSeries:
         polygeom=gpd.GeoSeries(polygeom)
 
+
     #create a dense enough grid encompassing all polgyons to use for spherical harmonic synthesis
     # heuristic way to figure out the resolution based on nmax
     dslonlat=xr.Dataset.sh.lonlat_grid(nmax,engine=engine)

@@ -66,9 +67,14 @@ def polygon2sh(polygeom,nmax:int=100,auxcoord=None,engine="shlib",**kwargs) ->xr
 
 
     dtmp=xr.DataArray(np.zeros([dslonlat.sizes['lon'],dslonlat.sizes['lat'],len(polygeom)]),coords=coords,dims=dims).stack(lonlat=("lon","lat"))
-
-    #create a geoDataframe of points from the grid
-    ggrd=gpd.GeoDataFrame(geometry=[Point(lon,lat) for lon,lat in dtmp.lonlat.values],crs=4326)
+    if dtmp.lon.min() < 0:
+        #grid has 0 central meridian already
+        #create a geoDataframe of points from the grid (lon is already with 0 central meridian
+        ggrd=gpd.GeoDataFrame(geometry=[Point(lon,lat) for lon,lat in dtmp.lonlat.values],crs=4326)
+    else:
+        #grid has 180 central meridian
+        #create a geoDataframe of points from the grid and convert lon to have 0 central meridian
+        ggrd=gpd.GeoDataFrame(geometry=[Point((lon+180)%360-180,lat) for lon,lat in dtmp.lonlat.values],crs=4326)
 
     if polygeom.crs != ggrd.crs:
         #possibly convert the lon/lat grid in the desired projection before doing the polygon test

@@ -77,7 +83,11 @@ def polygon2sh(polygeom,nmax:int=100,auxcoord=None,engine="shlib",**kwargs) ->xr
 
     #query using a spatial index and set values to 1
     shxlogger.info("Masking and gridding polygons")
-    for i,poly in enumerate(polygeom):
+    for i,poly in enumerate(polygeom):
+        # if i == 18:
+        # breakpoint()
+        # from IPython.core.debugger import set_trace
+        # set_trace()
         idx=ggrd.sindex.query(poly,predicate="contains")
         dtmp[i,idx]=1.0
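
The meridian fix above hinges on one wrap: longitudes from a 0–360 grid (180° central meridian, as produced by the shtns backend) are mapped into −180…180 before the shapely points are built, so the point-in-polygon test uses the same convention as the polygons. Below is a minimal standalone sketch of that wrap plus the spatial-index query; the grid spacing and toy polygon are invented for illustration and are not part of the commit.

import numpy as np
import geopandas as gpd
from shapely.geometry import Point, box

# Hypothetical grid with a 180 central meridian (lon runs 0..360)
lon = np.arange(0.0, 360.0, 10.0)
lat = np.arange(-80.0, 90.0, 10.0)

# The wrap used in the commit: 0..360 -> -180..180
lon_wrapped = (lon + 180.0) % 360.0 - 180.0

# Build the point GeoDataFrame and run the same "contains" predicate query
ggrd = gpd.GeoDataFrame(geometry=[Point(lo, la) for lo in lon_wrapped for la in lat], crs=4326)
poly = box(-30.0, -30.0, 30.0, 30.0)   # toy polygon straddling the Greenwich meridian
idx = ggrd.sindex.query(poly, predicate="contains")
print(f"{len(idx)} grid points fall inside the toy polygon")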

src/shxarray/kernels/anisokernel.py

Lines changed: 5 additions & 1 deletion

@@ -8,7 +8,7 @@
 from shxarray.shlib import Ynm
 from shxarray.core.sh_indexing import SHindexBase
 import sparse
-
+import numpy as np
 from packaging import version
 
 

@@ -48,8 +48,12 @@ def __call__(self,dain:xr.DataArray):
 
         if self.nmax < dain.sh.nmax:
             raise RuntimeError("Input data has higher degree than kernel, cannot apply kernel operator to object")
+
+        if type(dain.data) != np.ndarray:
+            dain=dain.compute()
 
         daout=xr.dot(self._dskernel.mat,dain,dims=[SHindexBase.name])
+
         #rename nm and convert to dense array
         daout=daout.sh.toggle_nm()
         if self.useDask:
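
The added guard materialises lazy (e.g. dask-backed) input as an in-memory numpy array before the kernel multiplication. A minimal sketch of the same idea on a plain xarray DataArray, assuming dask is installed; the array contents are invented for illustration:

import numpy as np
import xarray as xr

# A hypothetical dask-backed input
dain = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("nm", "time")).chunk({"nm": 1})

# Guard mirroring the commit: compute only if the data is not already a numpy array
if not isinstance(dain.data, np.ndarray):
    dain = dain.compute()

print(type(dain.data))   # <class 'numpy.ndarray'>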

src/shxarray/signal/basinav.py

Lines changed: 1 addition & 1 deletion

@@ -44,5 +44,5 @@ def __call__(self, datws,**kwargs):
         leakage=leakage_corr_vishwa2016(datws, self._dabin, self._filtername,engine=engine)
         da_av=(da_av-leakage)*dascales
 
-        return da_av
+        return da_av.drop_vars(['n','m','nm'])
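
The changed return strips leftover spherical-harmonic index coordinates ('n', 'm', 'nm') that the leakage arithmetic can drag onto the basin averages. A small illustrative sketch with toy data (not the library's objects) of how scalar coordinates propagate through a binary operation and how drop_vars removes them; 'nm' is included for parity with the commit, so errors="ignore" is used here:

import xarray as xr

datws = xr.DataArray([2.0, 3.0, 4.0], dims="time", coords={"time": [0, 1, 2]})
# A degree-0 coefficient carrying scalar 'n' and 'm' coordinates
c00 = xr.DataArray(4.0, coords={"n": 0, "m": 0})

da_av = datws / c00           # the scalar n/m coords propagate onto the result
print(list(da_av.coords))     # includes 'n' and 'm' alongside 'time'

da_clean = da_av.drop_vars(["n", "m", "nm"], errors="ignore")
print(list(da_clean.coords))  # ['time']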

src/shxarray/signal/leakage_vishwa2016.py

Lines changed: 7 additions & 4 deletions

@@ -8,7 +8,7 @@
 from scipy.signal import hilbert
 from shxarray.kernels import getSHfilter
 from shxarray.exp.multiply import multiply
-
+from shxarray.core.logging import shxlogger
 
 def leakage_corr_vishwa2016(datws, dabasins, filtername,engine='shlib'):
     """

@@ -59,9 +59,8 @@ def leakage_corr_vishwa2016(datws, dabasins, filtername,engine='shlib'):
     filterOp = getSHfilter(filtername,nmax=datws.sh.nmax)
     datws_f=filterOp(datws)
 
-
+
     Ic_f=((daknm@datws_f)/dabasins.sel(n=0,m=0)).interp({timedim:time_eq})
-
     #compute double filtered leakage signal
     datws_ff=filterOp(datws_f)
     Ic_ff=((daknm@datws_ff)/dabasins.sel(n=0,m=0)).interp({timedim:time_eq})

@@ -79,7 +78,11 @@ def leakage_corr_vishwa2016(datws, dabasins, filtername,engine='shlib'):
         A[:,0]=1
         A[:,1]=Ic_ff.isel({auxdim:i}).data
         A[:,2]=np.imag(hilbert(Ic_ff.isel({auxdim:i}).data))
-        reg_fit,_,_,_=np.linalg.lstsq(A,Ic_f.isel({auxdim:i}).data)
+        try:
+            reg_fit,_,_,_=np.linalg.lstsq(A,Ic_f.isel({auxdim:i}).data)
+        except np.linalg.LinAlgError:
+            shxlogger.warning(f"Least squares fit failed for basin {i}, phase for leakage could not be computed, assuming np phase shift")
+            reg_fit=[0,1,0]
         # compute phase=atan(c/b) to get the phase shift and take the complex exponential
         phase_exp.append(np.exp(-1j*(np.atan2(reg_fit[2],reg_fit[1]))))
 
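
The new try/except keeps a single failing basin from aborting the whole leakage correction: when np.linalg.lstsq raises LinAlgError, the fit coefficients fall back to [0, 1, 0], i.e. no phase shift. Below is a self-contained sketch of that guarded phase fit; the toy sine signals and the helper name are invented for illustration and are not the library's API.

import numpy as np
from scipy.signal import hilbert

def guarded_phase_factor(Ic_f, Ic_ff):
    # Design matrix: offset, in-phase and quadrature (Hilbert) components
    A = np.column_stack([np.ones_like(Ic_ff), Ic_ff, np.imag(hilbert(Ic_ff))])
    try:
        reg_fit, *_ = np.linalg.lstsq(A, Ic_f, rcond=None)
    except np.linalg.LinAlgError:
        reg_fit = [0, 1, 0]   # fallback mirroring the commit: assume no phase shift
    # phase = atan2(c, b); return the complex exponential used as the phase factor
    return np.exp(-1j * np.arctan2(reg_fit[2], reg_fit[1]))

t = np.linspace(0, 4 * np.pi, 200)
print(guarded_phase_factor(np.sin(t), np.sin(t)))   # ~ (1+0j): identical signals, no shift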
