Skip to content

Commit 3120b40

Browse files
authored
Update ram_init_final.rs
1 parent f36c2a3 commit 3120b40

File tree

1 file changed

+33
-29
lines changed

1 file changed

+33
-29
lines changed

prover/src/extensions/ram_init_final.rs

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,20 @@ impl FrameworkEvalExt for RamInitFinalEval {
153153
}
154154

155155
impl RamInitFinalEval {
156+
/// Converts WORD_SIZE bytes into a pair of 16-bit values (low, high)
157+
/// where low = bytes[0] + bytes[1] * 256
158+
/// and high = bytes[WORD_SIZE_HALVED] + bytes[WORD_SIZE_HALVED + 1] * 256
159+
fn bytes_to_word_halves<E: EvalAtRow>(bytes: &[E::F]) -> (E::F, E::F) {
160+
assert_eq!(bytes.len(), WORD_SIZE, "Expected {} bytes", WORD_SIZE);
161+
162+
let low = bytes[0].clone()
163+
+ bytes[1].clone() * E::F::from((1 << 8).into());
164+
let high = bytes[WORD_SIZE_HALVED].clone()
165+
+ bytes[WORD_SIZE_HALVED + 1].clone() * E::F::from((1 << 8).into());
166+
167+
(low, high)
168+
}
169+
156170
fn constrain_add_initial_values<E: EvalAtRow>(
157171
&self,
158172
eval: &mut E,
@@ -161,12 +175,9 @@ impl RamInitFinalEval {
161175
preprocessed_init_value: E::F,
162176
ram_init_final_flag: E::F,
163177
) {
178+
// Build tuple: [addr_low, addr_high, init_value, counter_zeros...]
164179
let mut tuple = vec![];
165-
// Build the tuple from the RAM address bytes.
166-
let addr_low = ram_init_final_addr[0].clone()
167-
+ ram_init_final_addr[1].clone() * E::F::from((1 << 8).into());
168-
let addr_high = ram_init_final_addr[2].clone()
169-
+ ram_init_final_addr[3].clone() * E::F::from((1 << 8).into());
180+
let (addr_low, addr_high) = Self::bytes_to_word_halves(ram_init_final_addr);
170181
tuple.push(addr_low);
171182
tuple.push(addr_high);
172183
// Add the product of preprocessed init flag and value.
@@ -192,19 +203,13 @@ impl RamInitFinalEval {
192203
ram_init_final_flag: E::F,
193204
) {
194205
let mut tuple = vec![];
195-
let addr_low = ram_init_final_addr[0].clone()
196-
+ ram_init_final_addr[1].clone() * E::F::from((1 << 8).into());
197-
let addr_high = ram_init_final_addr[2].clone()
198-
+ ram_init_final_addr[3].clone() * E::F::from((1 << 8).into());
206+
let (addr_low, addr_high) = Self::bytes_to_word_halves(ram_init_final_addr);
199207
tuple.push(addr_low);
200208
tuple.push(addr_high);
201209

202210
tuple.push(ram_final_value);
203211

204-
let counter_low = ram_final_counter[0].clone()
205-
+ ram_final_counter[1].clone() * E::F::from((1 << 8).into());
206-
let counter_high = ram_final_counter[2].clone()
207-
+ ram_final_counter[3].clone() * E::F::from((1 << 8).into());
212+
let (counter_low, counter_high) = Self::bytes_to_word_halves(ram_final_counter);
208213
tuple.push(counter_low);
209214
tuple.push(counter_high);
210215

@@ -344,6 +349,18 @@ impl BuiltInExtension for RamInitFinal {
344349
}
345350

346351
impl RamInitFinal {
352+
/// Packed version for SimdBackend: converts WORD_SIZE bytes from columns at given row
353+
/// into a pair of 16-bit values (low, high)
354+
fn bytes_to_word_halves_packed(byte_cols: &[BaseColumn], vec_row: usize) -> (PackedBaseField, PackedBaseField) {
355+
assert_eq!(byte_cols.len(), WORD_SIZE, "Expected {} byte columns", WORD_SIZE);
356+
357+
let shift = PackedBaseField::broadcast((1 << 8).into());
358+
let low = byte_cols[0].data[vec_row] + byte_cols[1].data[vec_row] * shift;
359+
let high = byte_cols[WORD_SIZE_HALVED].data[vec_row]
360+
+ byte_cols[WORD_SIZE_HALVED + 1].data[vec_row] * shift;
361+
362+
(low, high)
363+
}
347364
fn preprocessed_columns(log_size: u32, program_trace_ref: ProgramTraceRef) -> Vec<BaseColumn> {
348365
let total_len = program_trace_ref.init_memory.len()
349366
+ program_trace_ref.exit_code.len()
@@ -542,12 +559,7 @@ impl RamInitFinal {
542559
// Add (address, value, 0)
543560
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
544561
let mut tuple = vec![];
545-
let addr_low = ram_init_final_addr[0].data[vec_row]
546-
+ ram_init_final_addr[1].data[vec_row]
547-
* PackedBaseField::broadcast((1 << 8).into());
548-
let addr_high = ram_init_final_addr[2].data[vec_row]
549-
+ ram_init_final_addr[3].data[vec_row]
550-
* PackedBaseField::broadcast((1 << 8).into());
562+
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
551563
tuple.push(addr_low);
552564
tuple.push(addr_high);
553565

@@ -586,21 +598,13 @@ impl RamInitFinal {
586598
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
587599
let mut tuple = vec![];
588600

589-
let addr_low = ram_init_final_addr[0].data[vec_row]
590-
+ ram_init_final_addr[1].data[vec_row]
591-
* PackedBaseField::broadcast((1 << 8).into());
592-
let addr_high = ram_init_final_addr[2].data[vec_row]
593-
+ ram_init_final_addr[3].data[vec_row]
594-
* PackedBaseField::broadcast((1 << 8).into());
601+
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
595602
tuple.push(addr_low);
596603
tuple.push(addr_high);
597604

598605
tuple.push(ram_final_value.data[vec_row]);
599606

600-
let counter_low = ram_final_counter[0].data[vec_row]
601-
+ ram_final_counter[1].data[vec_row] * PackedBaseField::broadcast((1 << 8).into());
602-
let counter_high = ram_final_counter[2].data[vec_row]
603-
+ ram_final_counter[3].data[vec_row] * PackedBaseField::broadcast((1 << 8).into());
607+
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
604608
tuple.push(counter_low);
605609
tuple.push(counter_high);
606610

0 commit comments

Comments
 (0)