@@ -86,6 +86,21 @@ class PCA(sklearn.decomposition.PCA):
8686 If None, the random number generator is the RandomState instance used
8787 by `da.random`. Used when ``svd_solver`` == 'randomized'.
8888
89+ center : bool, optional (default True)
90+ When True (the default), the underlying data gets centered at zero
91+ by subtracting the mean of the data from the data itself.
92+
93+ PCA is performed on centered data because it is a linear model with
94+ no intercept term; as a result, its principal components originate at
95+ the origin of the transformed space.
96+
97+ ``center=False`` may be employed when performing PCA on already
98+ centered data.
99+
100+ Because whitening requires centered data, setting ``center=False``
101+ together with ``whiten=True`` may produce unexpected results unless
102+ the input data were already centered beforehand.
103+
89104 Attributes
90105 ----------
91106 components_ : array, shape (n_components, n_features)
@@ -152,18 +167,27 @@ class PCA(sklearn.decomposition.PCA):
152167 PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
153168 svd_solver='auto', tol=0.0, whiten=False)
154169 >>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
155- [ 0.99244... 0.00755... ]
170+ [0.99244289 0.00755711]
156171 >>> print(pca.singular_values_) # doctest: +ELLIPSIS
157- [ 6.30061... 0.54980... ]
172+ [6.30061232 0.54980396]
158173
159174 >>> pca = PCA(n_components=2, svd_solver='full')
160175 >>> pca.fit(dX) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
161176 PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
162177 svd_solver='full', tol=0.0, whiten=False)
163178 >>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
164- [ 0.99244... 0.00755...]
179+ [0.99244289 0.00755711]
180+ >>> print(pca.singular_values_) # doctest: +ELLIPSIS
181+ [6.30061232 0.54980396]
182+
183+ >>> dX_mean_0 = dX - dX.mean(axis=0)
184+ >>> pca = PCA(n_components=2, svd_solver='full', center=False)
185+ >>> pca.fit(dX_mean_0)
186+ PCA(center=False, n_components=2, svd_solver='full')
187+ >>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
188+ [0.99244289 0.00755711]
165189 >>> print(pca.singular_values_) # doctest: +ELLIPSIS
166- [ 6.30061... 0.54980... ]
190+ [6.30061232 0.54980396]
167191
168192 Notes
169193 -----
@@ -175,6 +199,10 @@ class PCA(sklearn.decomposition.PCA):
175199 ``dask.linalg.svd_compressed``.
176200 * n_components : ``n_components='mle'`` is not allowed.
177201 Fractional ``n_components`` between 0 and 1 is not allowed.
202+ * center : if ``True`` (the default), the input data are automatically
203+ centered before PCA is performed.
204+ Set this parameter to ``False`` if the input data have already been
205+ centered before calling ``fit()``.
178206 """
179207
180208 def __init__ (
@@ -186,10 +214,12 @@ def __init__(
186214 tol = 0.0 ,
187215 iterated_power = 0 ,
188216 random_state = None ,
217+ center = True ,
189218 ):
190219 self .n_components = n_components
191220 self .copy = copy
192221 self .whiten = whiten
222+ self .center = center
193223 self .svd_solver = svd_solver
194224 self .tol = tol
195225 self .iterated_power = iterated_power
@@ -198,6 +228,7 @@ def __init__(
198228 def fit (self , X , y = None ):
199229 if not dask .is_dask_collection (X ):
200230 raise TypeError (_TYPE_MSG .format (type (X )))
231+
201232 self ._fit (X )
202233 self .n_features_in_ = X .shape [1 ]
203234 return self
@@ -266,8 +297,10 @@ def _fit(self, X):
266297
267298 solver = self ._get_solver (X , n_components )
268299
269- self .mean_ = X .mean (0 )
270- X -= self .mean_
300+ self .mean_ = X .mean (axis = 0 )
301+
302+ if self .center :
303+ X -= self .mean_
271304
272305 if solver in {"full" , "tsqr" }:
273306 U , S , V = da .linalg .svd (X )
@@ -370,14 +403,20 @@ def transform(self, X):
370403 X_new : array-like, shape (n_samples, n_components)
371404
372405 """
373- check_is_fitted (self , ["mean_" , "components_" ])
406+ check_is_fitted (self , "components_" )
407+
408+ if self .whiten :
409+ check_is_fitted (self , "explained_variance_" )
410+
411+ if self .center :
412+ check_is_fitted (self , "mean_" )
413+ if self .mean_ is not None :
414+ X -= self .mean_
374415
375- # X = check_array(X)
376- if self .mean_ is not None :
377- X = X - self .mean_
378416 X_transformed = da .dot (X , self .components_ .T )
379417 if self .whiten :
380418 X_transformed /= np .sqrt (self .explained_variance_ )
419+
381420 return X_transformed
382421
383422 def fit_transform (self , X , y = None ):
@@ -396,7 +435,6 @@ def fit_transform(self, X, y=None):
396435 X_new : array-like, shape (n_samples, n_components)
397436
398437 """
399- # X = check_array(X)
400438 if not dask .is_dask_collection (X ):
401439 raise TypeError (_TYPE_MSG .format (type (X )))
402440 U , S , V = self ._fit (X )
@@ -431,18 +469,25 @@ def inverse_transform(self, X):
431469 If whitening is enabled, inverse_transform does not compute the
432470 exact inverse operation of transform.
433471 """
434- check_is_fitted (self , "mean_" )
472+ check_is_fitted (self , "components_" )
473+
474+ if self .center :
475+ check_is_fitted (self , "mean_" )
476+ offset = self .mean_
477+ else :
478+ offset = 0
435479
436480 if self .whiten :
481+ check_is_fitted (self , "explained_variance_" )
437482 return (
438483 da .dot (
439484 X ,
440485 np .sqrt (self .explained_variance_ [:, np .newaxis ]) * self .components_ ,
441486 )
442- + self . mean_
487+ + offset
443488 )
444- else :
445- return da .dot (X , self .components_ ) + self . mean_
489+
490+ return da .dot (X , self .components_ ) + offset
446491
447492 def score_samples (self , X ):
448493 """Return the log-likelihood of each sample.
@@ -463,8 +508,11 @@ def score_samples(self, X):
463508 """
464509 check_is_fitted (self , "mean_" )
465510
466- # X = check_array(X)
467- Xr = X - self .mean_
511+ if self .center :
512+ Xr = X - self .mean_
513+ else :
514+ Xr = X
515+
468516 n_features = X .shape [1 ]
469517 precision = self .get_precision () # [n_features, n_features]
470518 log_like = - 0.5 * (Xr * (da .dot (Xr , precision ))).sum (axis = 1 )
0 commit comments