Skip to content

Commit 5968778

Browse files
committed
chore(gpu): add better panic when calling old compression functions
1 parent a5c2485 commit 5968778

File tree

3 files changed

+103
-33
lines changed

3 files changed

+103
-33
lines changed

tfhe/src/high_level_api/booleans/compressed.rs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ use crate::backward_compatibility::booleans::{
22
CompressedFheBoolVersions, InnerCompressedFheBoolVersions,
33
};
44
use crate::conformance::ParameterSetConformant;
5-
use crate::high_level_api::global_state::with_cpu_internal_keys;
5+
use crate::high_level_api::global_state;
6+
use crate::high_level_api::keys::InternalServerKey;
67
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
78
use crate::high_level_api::traits::Tagged;
89
use crate::integer::BooleanBlock;
@@ -74,17 +75,27 @@ impl CompressedFheBool {
7475
let ciphertext = BooleanBlock::new_unchecked(match &self.inner {
7576
InnerCompressedFheBool::Seeded(seeded) => seeded.decompress(),
7677
InnerCompressedFheBool::ModulusSwitched(modulus_switched) => {
77-
with_cpu_internal_keys(|sk| sk.pbs_key().key.decompress(modulus_switched))
78+
global_state::with_internal_keys(|keys| match keys {
79+
InternalServerKey::Cpu(cpu_key) => {
80+
cpu_key.pbs_key().key.decompress(modulus_switched)
81+
}
82+
#[cfg(feature = "gpu")]
83+
InternalServerKey::Cuda(_) => {
84+
panic!("decompress() on FheBool is not supported on GPU, use a CompressedCiphertextList instead");
85+
}
86+
#[cfg(feature = "hpu")]
87+
InternalServerKey::Hpu(_) => {
88+
panic!("decompress() on FheBool is not supported on HPU devices");
89+
}
90+
})
7891
}
7992
});
8093
let mut ciphertext = FheBool::new(
8194
ciphertext,
8295
self.tag.clone(),
8396
ReRandomizationMetadata::default(),
8497
);
85-
8698
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
87-
8899
ciphertext
89100
}
90101
}
@@ -117,16 +128,26 @@ impl Named for CompressedFheBool {
117128

118129
impl FheBool {
119130
pub fn compress(&self) -> CompressedFheBool {
120-
with_cpu_internal_keys(|sk| {
121-
let inner = InnerCompressedFheBool::ModulusSwitched(
122-
sk.pbs_key()
123-
.key
124-
.switch_modulus_and_compress(&self.ciphertext.on_cpu().0),
125-
);
126-
127-
CompressedFheBool {
128-
inner,
129-
tag: sk.tag.clone(),
131+
global_state::with_internal_keys(|keys| match keys {
132+
InternalServerKey::Cpu(cpu_key) => {
133+
let inner = InnerCompressedFheBool::ModulusSwitched(
134+
cpu_key
135+
.pbs_key()
136+
.key
137+
.switch_modulus_and_compress(&self.ciphertext.on_cpu().0),
138+
);
139+
CompressedFheBool {
140+
inner,
141+
tag: cpu_key.tag.clone(),
142+
}
143+
}
144+
#[cfg(feature = "gpu")]
145+
InternalServerKey::Cuda(_) => {
146+
panic!("compress() on FheBool is not supported on GPU, use a CompressedCiphertextList instead");
147+
}
148+
#[cfg(feature = "hpu")]
149+
InternalServerKey::Hpu(_) => {
150+
panic!("compress() on FheBool is not supported on HPU devices");
130151
}
131152
})
132153
}

tfhe/src/high_level_api/integers/signed/compressed.rs

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use crate::backward_compatibility::integers::{
55
};
66
use crate::conformance::ParameterSetConformant;
77
use crate::core_crypto::prelude::SignedNumeric;
8-
use crate::high_level_api::global_state::with_cpu_internal_keys;
8+
use crate::high_level_api::global_state;
99
use crate::high_level_api::integers::signed::base::FheIntConformanceParams;
1010
use crate::high_level_api::integers::{FheInt, FheIntId};
11+
use crate::high_level_api::keys::InternalServerKey;
1112
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
1213
use crate::high_level_api::traits::Tagged;
1314
use crate::integer::block_decomposition::DecomposableInto;
@@ -111,7 +112,19 @@ where
111112
let ciphertext = match &self.ciphertext {
112113
CompressedSignedRadixCiphertext::Seeded(ct) => ct.decompress(),
113114
CompressedSignedRadixCiphertext::ModulusSwitched(ct) => {
114-
with_cpu_internal_keys(|sk| sk.pbs_key().decompress_signed_parallelized(ct))
115+
global_state::with_internal_keys(|keys| match keys {
116+
InternalServerKey::Cpu(cpu_key) => {
117+
cpu_key.pbs_key().decompress_signed_parallelized(ct)
118+
}
119+
#[cfg(feature = "gpu")]
120+
InternalServerKey::Cuda(_) => {
121+
panic!("decompress() on FheInt is not supported on GPU, use a CompressedCiphertextList instead");
122+
}
123+
#[cfg(feature = "hpu")]
124+
InternalServerKey::Hpu(_) => {
125+
panic!("decompress() on FheInt is not supported on HPU devices");
126+
}
127+
})
115128
}
116129
};
117130
FheInt::new(
@@ -180,14 +193,25 @@ where
180193
Id: FheIntId,
181194
{
182195
pub fn compress(&self) -> CompressedFheInt<Id> {
183-
let a = with_cpu_internal_keys(|sk| {
184-
sk.pbs_key()
185-
.switch_modulus_and_compress_signed_parallelized(&self.ciphertext.on_cpu())
186-
});
187-
188-
CompressedFheInt::new(
189-
CompressedSignedRadixCiphertext::ModulusSwitched(a),
190-
self.tag.clone(),
191-
)
196+
global_state::with_internal_keys(|keys| match keys {
197+
InternalServerKey::Cpu(cpu_key) => {
198+
let a = cpu_key
199+
.pbs_key()
200+
.switch_modulus_and_compress_signed_parallelized(&self.ciphertext.on_cpu());
201+
202+
CompressedFheInt::new(
203+
CompressedSignedRadixCiphertext::ModulusSwitched(a),
204+
self.tag.clone(),
205+
)
206+
}
207+
#[cfg(feature = "gpu")]
208+
InternalServerKey::Cuda(_) => {
209+
panic!("compress() on FheInt is not supported on GPU, use a CompressedCiphertextList instead");
210+
}
211+
#[cfg(feature = "hpu")]
212+
InternalServerKey::Hpu(_) => {
213+
panic!("compress() on FheInt is not supported on HPU devices");
214+
}
215+
})
192216
}
193217
}

tfhe/src/high_level_api/integers/unsigned/compressed.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ use crate::backward_compatibility::integers::{
55
};
66
use crate::conformance::ParameterSetConformant;
77
use crate::core_crypto::prelude::UnsignedNumeric;
8-
use crate::high_level_api::global_state::with_cpu_internal_keys;
98
use crate::high_level_api::integers::unsigned::base::{
109
FheUint, FheUintConformanceParams, FheUintId,
1110
};
11+
use crate::high_level_api::keys::InternalServerKey;
1212
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
1313
use crate::high_level_api::traits::{FheTryEncrypt, Tagged};
14-
use crate::high_level_api::ClientKey;
14+
use crate::high_level_api::{global_state, ClientKey};
1515
use crate::integer::block_decomposition::DecomposableInto;
1616
use crate::integer::ciphertext::{
1717
CompressedModulusSwitchedRadixCiphertext,
@@ -108,7 +108,19 @@ where
108108
let inner = match &self.ciphertext {
109109
CompressedRadixCiphertext::Seeded(ct) => ct.decompress(),
110110
CompressedRadixCiphertext::ModulusSwitched(ct) => {
111-
with_cpu_internal_keys(|sk| sk.pbs_key().decompress_parallelized(ct))
111+
global_state::with_internal_keys(|keys| match keys {
112+
InternalServerKey::Cpu(cpu_key) => {
113+
cpu_key.pbs_key().decompress_parallelized(ct)
114+
}
115+
#[cfg(feature = "gpu")]
116+
InternalServerKey::Cuda(_) => {
117+
panic!("decompress() on FheUint is not supported on GPU, use a CompressedCiphertextList instead");
118+
}
119+
#[cfg(feature = "hpu")]
120+
InternalServerKey::Hpu(_) => {
121+
panic!("decompress() on FheUint is not supported on HPU devices");
122+
}
123+
})
112124
}
113125
};
114126

@@ -179,11 +191,24 @@ where
179191
Id: FheUintId,
180192
{
181193
pub fn compress(&self) -> CompressedFheUint<Id> {
182-
let ciphertext = CompressedRadixCiphertext::ModulusSwitched(with_cpu_internal_keys(|sk| {
183-
sk.pbs_key()
184-
.switch_modulus_and_compress_parallelized(&self.ciphertext.on_cpu())
185-
}));
186-
CompressedFheUint::new(ciphertext, self.tag.clone())
194+
global_state::with_internal_keys(|keys| match keys {
195+
InternalServerKey::Cpu(cpu_key) => {
196+
let ciphertext = CompressedRadixCiphertext::ModulusSwitched(
197+
cpu_key
198+
.pbs_key()
199+
.switch_modulus_and_compress_parallelized(&self.ciphertext.on_cpu()),
200+
);
201+
CompressedFheUint::new(ciphertext, self.tag.clone())
202+
}
203+
#[cfg(feature = "gpu")]
204+
InternalServerKey::Cuda(_) => {
205+
panic!("compress() on FheUint is not supported on GPU, use a CompressedCiphertextList instead");
206+
}
207+
#[cfg(feature = "hpu")]
208+
InternalServerKey::Hpu(_) => {
209+
panic!("compress() on FheUint is not supported on HPU devices");
210+
}
211+
})
187212
}
188213
}
189214

0 commit comments

Comments
 (0)