Skip to content

Commit 22e2583

Browse files
committed
Added functions
1 parent df7803f commit 22e2583

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
name = "general_class_balancer"
44

5-
version = "0.0.6"
5+
version = "0.0.7"
66

77
dependencies = [
88
"numpy>=1.22.0",
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
import importlib.metadata
22

3-
__version__ = "0.0.6" #importlib.metadata.version(__package__ or __name__)
3+
__version__ = "0.0.7" #importlib.metadata.version(__package__ or __name__)

src/general_class_balancer/general_class_balancer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,37 @@ def is_nan(k,inc_null_str=False):
4747
else:
4848
return False
4949

50+
def bucketize(arr,n_buckets):
51+
non_arr_list = []
52+
max_ = -np.Inf
53+
min_ = np.Inf
54+
for i in range(len(arr)):
55+
if not is_nan(arr[i]):
56+
if isinstance(arr[i],str): return arr
57+
non_arr_list.append(arr[i])
58+
if arr[i] > max_: max_ = arr[i]
59+
if arr[i] < min_: min_ = arr[i]
60+
bucketized_list = np.array(["NaN" for i in range(len(arr))],
61+
dtype=np.dtype(object))
62+
non_arr_list = sorted(non_arr_list)
63+
skips = int(len(non_arr_list)/float(n_buckets)) + 1
64+
buckets = np.array(non_arr_list[::skips])
65+
range_dist=((np.arange(n_buckets)/float(n_buckets-1))*(max_-min_))+min_
66+
while len(buckets) < n_buckets:
67+
print(buckets)
68+
buckets = np.array([buckets[0]] + list(buckets))
69+
buckets = (range_dist + buckets) / 2
70+
for i in range(len(arr)):
71+
if not is_nan(arr[i]):
72+
for j in range(len(buckets)-1):
73+
if arr[i] > buckets[j] and \
74+
arr[i] <= buckets[j+1]:
75+
bucketized_list[i] = str(j)
76+
break
77+
return bucketized_list
78+
79+
#
80+
5081
# This method uses prime numbers to speed up datapoint matching. Each bucket
5182
# gets a prime number, and each datapoint is assigned a product of these primes.
5283
# These are then matched with one another.

0 commit comments

Comments
 (0)