diff --git a/Cargo.lock b/Cargo.lock index f738247ef..e7ca6fb90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -522,6 +522,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4445909572dbd556c457c849c4ca58623d84b27c8fff1e74b0b4227d8b90d17b" +[[package]] +name = "elf" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55dd888a213fc57e957abf2aa305ee3e8a28dbe05687a251f33b637cd46b0070" + [[package]] name = "elf_loader" version = "0.12.0" @@ -531,7 +537,7 @@ dependencies = [ "bitflags 2.9.4", "cfg-if", "delegate", - "elf", + "elf 0.7.4", ] [[package]] @@ -863,7 +869,7 @@ dependencies = [ "cms", "const-oid", "digest", - "elf", + "elf 0.8.0", "hashbrown 0.15.5", "libc", "litebox", @@ -989,7 +995,7 @@ dependencies = [ "arrayvec", "bitflags 2.9.4", "bitvec", - "elf", + "elf 0.7.4", "elf_loader", "litebox", "litebox_common_linux", @@ -1010,7 +1016,7 @@ dependencies = [ "bitflags 2.9.4", "cfg-if", "ctr", - "elf", + "elf 0.7.4", "elf_loader", "hashbrown 0.15.5", "litebox", diff --git a/litebox_platform_lvbs/Cargo.toml b/litebox_platform_lvbs/Cargo.toml index 2c5e4090c..6278723eb 100644 --- a/litebox_platform_lvbs/Cargo.toml +++ b/litebox_platform_lvbs/Cargo.toml @@ -25,7 +25,7 @@ num_enum = { version = "0.7.3", default-features = false } once_cell = { version = "1.20.2", default-features = false, features = ["alloc", "race"] } modular-bitfield = { version = "0.12.0", default-features = false } hashbrown = "0.15.2" -elf = { version = "0.7.4", default-features = false } +elf = { version = "0.8.0", default-features = false } cms = { version = "0.2.3", default-features = false, features = ["alloc"] } rsa = { version = "0.9.8", default-features = false } sha1 = { version = "0.10.6", default-features = false, features = ["force-soft"] } diff --git a/litebox_platform_lvbs/src/mshv/heki.rs b/litebox_platform_lvbs/src/mshv/heki.rs index cd12e538c..989c0fe4c 100644 --- a/litebox_platform_lvbs/src/mshv/heki.rs +++ b/litebox_platform_lvbs/src/mshv/heki.rs @@ -62,8 +62,7 @@ pub enum HekiKexecType { #[default] Unknown = 0xffff_ffff_ffff_ffff, } - -#[derive(Clone, Copy, Default, Debug, TryFromPrimitive, PartialEq)] +#[derive(Clone, Copy, Default, Debug, TryFromPrimitive, PartialEq, Hash, Eq)] #[repr(u64)] pub enum ModMemType { Text = 0, @@ -75,10 +74,59 @@ pub enum ModMemType { InitRoData = 6, ElfBuffer = 7, Patch = 8, + Syms = 9, + GplSyms = 10, #[default] Unknown = 0xffff_ffff_ffff_ffff, } +impl From for ModMemType { + fn from(i: usize) -> Self { + match i { + 0 => ModMemType::Text, + 1 => ModMemType::Data, + 2 => ModMemType::RoData, + 3 => ModMemType::RoAfterInit, + 4 => ModMemType::InitText, + 5 => ModMemType::InitData, + 6 => ModMemType::InitRoData, + _ => ModMemType::Unknown, + } + } +} + +impl From for usize { + fn from(m: ModMemType) -> Self{ + match m { + ModMemType::Text => 0, + ModMemType::Data => 1, + ModMemType::RoData => 2, + ModMemType::RoAfterInit=>3, + ModMemType::InitText=>4, + ModMemType::InitData=>5, + ModMemType::InitRoData=>6, + _=>0xffff_ffff_ffff_ffff + } + } +} +/* +impl Into for usize { + + fn into(self) -> ModMemType { + match self { + ModMemType::Text => 0, + ModMemType::Data => 1, + ModMemType::RoData => 2, + ModMemType::RoAfterInit=>3, + ModMemType::InitText=>4, + ModMemType::InitData=>5, + ModMemType::InitRoData=>6, + _=>0xffff_ffff_ffff_ffff + } + } +} + */ + pub(crate) fn mod_mem_type_to_mem_attr(mod_mem_type: ModMemType) -> MemAttr { let mut mem_attr = MemAttr::empty(); diff --git a/litebox_platform_lvbs/src/mshv/kmod.rs b/litebox_platform_lvbs/src/mshv/kmod.rs new file mode 100644 index 000000000..569c90475 --- /dev/null +++ b/litebox_platform_lvbs/src/mshv/kmod.rs @@ -0,0 +1,600 @@ +// TODO: file header comments + +use crate::{ + mshv::{ + heki::ModMemType, + vsm::{ModuleMemory, ModuleMemoryMetadata}, + vtl1_mem_layout::PAGE_SIZE, + }, + serial_print, serial_println, +}; +use aligned_vec::{AVec, ConstAlign, avec}; +use alloc::{format, string::String, vec, vec::Vec}; +use core::{mem, ops::Range}; +use elf::{ + ElfBytes, abi as ElfAbi, + endian::AnyEndian, + section::{SectionHeader, SectionHeaderTable}, + string_table::StringTable, + symbol::{Symbol, SymbolTable}, +}; +use thiserror::Error as ThisError; +use x86_64::VirtAddr; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("Error: {0}")] + Generic(String), + + #[error("Not Found: {0}")] + NotFound(String), + + #[error("Elf type is not supported")] + UnsupportedElf, + + #[error("Unsuppoted: {0}")] + Unsupported(String), + + #[error("Elf section {0} missing")] + MissingSection(String), + + #[error("Arithmetic error: {0}")] + Arithmetic(String), + + #[error("Bad symbol name")] + BadSymbolName, + + #[error("ElfParse: {0}")] + Parser(#[from] elf::ParseError), +} + +//#[derive(Default)] +//struct ModMemBuf { +// section: Vec<(usize, Range)>, +// len: usize, //I dont need to store this, alloc buf as end and keep size there +// buf: Vec, +//} + +#[derive(Copy, Clone)] +struct ModMemMask { + mem_type: ModMemType, + allow_mask: u64, + forbid_mask: u64, + init: bool, +} + +#[derive(Clone)] +struct ModSection { + sh_index: usize, + shdr: SectionHeader, //TODO: do i need this, I can hget it by api +} + +#[derive(Clone)] +struct ModMem { + //mem_type: ModMemType, + section_map: Vec<(ModSection, Range)>, + len: usize, // Buf size grows as we process section headers. + // Keep track of length seprately so we can call + // expensive buf.resize() operation once at the + // end of processing headers + buf: AVec>, + vtl0_va: VirtAddr, +} + +impl Default for ModMem { + fn default() -> Self { + ModMem { + section_map: Vec::new(), + len: 0, + buf: avec![[PAGE_SIZE] | 0u8; 0], + vtl0_va: VirtAddr::new(0), + } + } +} + +impl ModMem { + fn new(vtl0_va: VirtAddr) -> Self { + ModMem { + section_map: Vec::new(), + len: 0, + buf: avec![[PAGE_SIZE] | 0u8; 0], + vtl0_va, + } + } + + fn get_section(&self, sh_index: usize) -> Option<(VirtAddr, &[u8])> { + for (section, range) in &self.section_map { + if sh_index == section.sh_index { + return Some((self.vtl0_va + range.start as u64, &self.buf[range.clone()])); + } + } + None + } +} + +fn get_section_buf(mem_map: &mut [ModMem], sh_index: usize) -> Option<(&mut [u8], VirtAddr)> { + for mem in mem_map { + for (section, range) in &mem.section_map { + if sh_index == section.sh_index { + return Some(( + &mut mem.buf[range.clone()], + mem.vtl0_va + range.start as u64, + )); + } + } + } + None +} + +fn get_section_va(mem_map: &Vec, sh_index: usize) -> Option { + for mem in mem_map { + let va = mem.vtl0_va; + for (section, range) in &mem.section_map { + if sh_index == section.sh_index { + return Some(va + range.start as u64); + } + } + } + None +} + +fn section_is_alloc(mem_map: &Vec, sh_index: usize) -> bool { + for mem in mem_map { + for (section, _) in &mem.section_map { + if sh_index == section.sh_index { + return true; + } + } + } + false +} + +struct ElfParams<'a> { + elf: &'a ElfBytes<'a, AnyEndian>, + shdrs: &'a SectionHeaderTable<'a, AnyEndian>, + shdr_strtab: &'a StringTable<'a>, + sym_hdr: &'a SymbolTable<'a, AnyEndian>, + sym_strtab: &'a StringTable<'a>, +} + +pub fn valid_elf( + bytes: &[u8], + module_in_memory: &ModuleMemory, + mod_mem_metadata: &ModuleMemoryMetadata, +) -> Result<(), Error> { + let elf = ElfBytes::::minimal_parse(bytes)?; + + let (Some(shdrs), Some(shdr_strtab)) = elf.section_headers_with_strtab()? else { + return Err(Error::MissingSection(String::from("header table"))); + }; + + let Some((sym_hdr, sym_strtab)) = elf.symbol_table()? else { + return Err(Error::MissingSection(String::from("symbol table"))); + }; + + let elf_params = ElfParams { + elf: &elf, + shdrs: &shdrs, + shdr_strtab: &shdr_strtab, + sym_hdr: &sym_hdr, + sym_strtab: &sym_strtab, + }; + + // Check for Linux-specific elf attributes + check_linux_elf(&elf_params)?; + + // Categorize section headers under module memory types + // using linux-specific algo + let mut mem_map = layout_elf(&elf_params, mod_mem_metadata)?; + + relocate_elf(&elf_params, &mut mem_map)?; + + finalize_elf(&elf_params, &mut mem_map); + + let elf_text = mem_map.get(usize::from(ModMemType::Text)).unwrap(); + let mut mem_text_buf = vec![0u8; module_in_memory.text.len()]; + let mut mem_text_buf = avec![[PAGE_SIZE] | 0u8; module_in_memory.text.len()]; + module_in_memory + .text + .read_bytes(module_in_memory.text.start().unwrap(), &mut mem_text_buf) + .map_err(|_| Error::MissingSection(String::from("no text to compare")))?; + + serial_println!( + "Elf text length: {} mem_tex_length:{}", + elf_text.buf.len(), + module_in_memory.text.len() + ); + + if elf_text.buf == mem_text_buf { + serial_println!("Text matched!!!"); + } else { + serial_println!("Text did NOT match"); + } + + Ok(()) +} + +fn check_linux_elf(elf_params: &ElfParams) -> Result<(), Error> { + let elf = elf_params.elf; + if elf.ehdr.class != elf::file::Class::ELF64 + || elf.ehdr.e_type != elf::abi::ET_REL + || elf.ehdr.e_machine != elf::abi::EM_X86_64 + { + return Err(Error::UnsupportedElf); + } + + let Some(_shdr_modinfo) = elf.section_header_by_name(".modinfo")? else { + return Err(Error::MissingSection(String::from(".modinfo"))); + }; + let Some(shdr_gnu) = elf.section_header_by_name(".gnu.linkonce.this_module")? else { + return Err(Error::MissingSection(String::from( + ".gnu.linkonce.this_module", + ))); + }; + if shdr_gnu.sh_flags & u64::from(elf::abi::SHF_ALLOC) == 0 { + return Err(Error::MissingSection(String::from( + "gnu.linkonce.this_module", + ))); + } + // TODO: Maybe validate against struct module (why? hellps with jump labels and other optional struct) + // TODO: It would pay to also check crc, or find some other way to veriy stuct module is as expected (correct options set) + // For hints see early_mod_check, check_modstruct_version check_modinfo(versionmagic) + // TODO: Get flag MODULE_INIT_IGNORE_MODVERSIONS to determine if __versions should be ignored */ + Ok(()) +} + +fn layout_elf( + elf_params: &ElfParams, + mem_metadata: &ModuleMemoryMetadata, +) -> Result, Error> { + //let mut mem_map = vec![ModMem::default(); ModMemType::InitRoData as usize]; + + let mut mem_map: Vec = (usize::from(ModMemType::Text) + ..usize::from(ModMemType::InitRoData)) + .map(|m| { + let mem_type = ModMemType::from(m); + serial_println!("layout_elf: init: mod_mem_type: {:?} m:{}", mem_type, m); + ModMem::new( + mem_metadata + .get_mem_type_va(mem_type) + .or(Some(VirtAddr::zero())) + .unwrap(), + ) + }) + .collect(); + + for (sh_index, shdr) in elf_params.shdrs.iter().enumerate() { + if sh_index == 0 { + continue; + } + + let sh_name = elf_params.shdr_strtab.get(shdr.sh_name as usize)?; + if MOD_SECTION_SKIP.contains(&sh_name) { + serial_println!("kmod:layout_elf:{sh_name} skipped"); + continue; + } + + let mut sh_flags = shdr.sh_flags; + if MOD_SECTION_RO_AFTER_INIT.contains(&sh_name) { + sh_flags |= SHF_RO_AFTER_INIT as u64; + } + + let Some(mask) = MOD_MEM_MASK.iter().find(|&mask| { + sh_flags & mask.allow_mask == mask.allow_mask + && sh_flags & mask.forbid_mask == 0 + && sh_name.contains(".init") == mask.init + }) else { + continue; + }; + + let mem = mem_map.get_mut(mask.mem_type as usize).unwrap(); //TODO: Should we just mek extended enums like elf, patch extended? + + let buf_start = mem.len.next_multiple_of(shdr.sh_addralign as usize); + let buf_end = buf_start + .checked_add(shdr.sh_size as usize) + .ok_or(Error::Arithmetic(format!( + "Buf size too large for section {}: start:{} size:{}", + sh_index, buf_start, shdr.sh_size + )))?; + + mem.section_map + .push((ModSection { sh_index, shdr }, buf_start..buf_end)); + mem.len = buf_end; + + serial_println!( + "kmod:layout_elf:{}: index:{} mem_type:{:?} va:{:?} size:{}/{} flag:{}", + sh_name, + sh_index, + mask.mem_type, + mem.vtl0_va + buf_start as u64, + shdr.sh_size, + buf_end - buf_start, + shdr.sh_type + ); + } + + for mem in mem_map.iter_mut() { + mem.len = mem.len.next_multiple_of(PAGE_SIZE); + mem.buf.resize(mem.len, 0); + for (section, range) in &mem.section_map { + let (data, _) = elf_params.elf.section_data(§ion.shdr)?; + if section.shdr.sh_type == ElfAbi::SHT_NOBITS { + mem.buf[range.clone()].fill(0); + } else { + mem.buf[range.clone()].copy_from_slice(data); + } + } + } + + Ok(mem_map) +} + +fn get_symbol_value(elf_params: &ElfParams, mem_map: &Vec, sym: &Symbol) -> Option { + match sym.st_shndx { + shn_rsvd @ ElfAbi::SHN_LORESERVE..=ElfAbi::SHN_HIRESERVE => match shn_rsvd { + SHN_LIVEPATCH => None, + ElfAbi::SHN_ABS => Some(sym.st_value), + ElfAbi::SHN_COMMON => None, + _ => None, + }, + ElfAbi::SHN_UNDEF => resolve_symbol(elf_params, &sym), + sh_index => { + if let Some(vtl0_va) = get_section_va(mem_map, sh_index as usize) { + if vtl0_va != VirtAddr::zero() { + Some(vtl0_va.as_u64() + sym.st_value) + } else { + None + } + } else { + None + } + } + } +} + +fn relocate_elf(elf_params: &ElfParams, mem_map: &mut Vec) -> Result<(), Error> { + for (sh_index, shdr) in elf_params.shdrs.iter().enumerate() { + if shdr.sh_type != ElfAbi::SHT_RELA { + continue; + } + + if !section_is_alloc(mem_map, shdr.sh_info as usize) { + serial_println!( + "Skipping {}", + elf_params.shdr_strtab.get(shdr.sh_name as usize)? + ); + continue; + } + + serial_println!( + "==Relocate from {}", + elf_params.shdr_strtab.get(shdr.sh_name as usize)? + ); + + for rela in elf_params.elf.section_data_as_relas(&shdr)? { + let Ok(sym) = elf_params.sym_hdr.get(rela.r_sym as usize) else { + continue; + }; + + //serial_println!("Find sym:{:?}", sym, ); + let Some(st_value) = get_symbol_value(elf_params, &*mem_map, &sym) else { + return Err(Error::NotFound(format!( + "Value for symbol sh_index:{} in rela sh_index {}", + sym.st_shndx, sh_index + ))); + }; + + let sym_value = (st_value as i64 + rela.r_addend) as u64; + /*serial_println!("Final sym_value: {:#x} rela_addend:{:#x}",sym_value, rela.r_addend);*/ + + //add offset to dst_buf + /* + crate::serial_print!( + "{}: Sym:{} type:{} offset:{:#x} addend:{:#x} ({}) st_value:{:#x} val_start:{:#x} ", + index, + rela.r_sym, + rela.r_type, + rela.r_offset, + rela.r_addend, + rela.r_addend, + sym.st_value, + sym_value + );*/ + // Mutable dst_buf must be evaluated here to keep satisfy borrow checker + let Some((dst_buf, vtl0_va)) = get_section_buf(mem_map, shdr.sh_info as usize) else { + continue; //TODO: clean up, see how to move this out ot the loop and keep nborrow checkwre happy + }; + + if vtl0_va == VirtAddr::zero() { + continue; + } + + let src: &[u8] = match rela.r_type { + ElfAbi::R_X86_64_NONE => continue, + ElfAbi::R_X86_64_64 => &sym_value.to_ne_bytes(), + ElfAbi::R_X86_64_32 => &u32::try_from(sym_value).unwrap().to_ne_bytes(), + ElfAbi::R_X86_64_32S => &i32::try_from(sym_value as i64).unwrap().to_ne_bytes(), + ElfAbi::R_X86_64_PC32 | ElfAbi::R_X86_64_PLT32 => { + let va = vtl0_va.as_u64() + rela.r_offset; + let sym_value = sym_value.wrapping_sub(va); + &i32::try_from(sym_value as i64).unwrap().to_ne_bytes() + } + ElfAbi::R_X86_64_PC64 => { + let va = vtl0_va.as_u64() + rela.r_offset; + let sym_value = sym_value.wrapping_sub(va); + &(sym_value as i64).to_ne_bytes() + } + _ => panic!("Bad rela"), + }; + /* + if src.len() == 4 { + crate::serial_println!("val:{:#x}", u32::from_ne_bytes(src.try_into().unwrap())); + } else { + crate::serial_println!("val:{:#x}", u64::from_ne_bytes(src.try_into().unwrap())); + } + */ + + let dst_offset = rela.r_offset as usize; + let dst = &mut dst_buf[dst_offset..dst_offset + src.len()]; + dst.copy_from_slice(&src); + } + } + Ok(()) +} + +#[repr(C)] +struct ParavirtPatchSite { + instr: *mut u8, + typ: u8, + len: u8, +} + +fn resolve_symbol(elf_params: &ElfParams, sym: &Symbol) -> Option { + // Get symbol str name + if sym.st_name == 0 { + return None; + } + let sym_name = match elf_params.sym_strtab.get(sym.st_name as usize) { + Ok(sym_name) => sym_name, + Err(_) => { + serial_println!("Symbol not found"); //TODO clean this up + return None; + } + }; + crate::platform_low().vtl0_kernel_info.find_symbol(sym_name) +} + +fn finalize_elf(elf_params: &ElfParams, mem_map: &mut Vec) -> Result<(), Error> { + for (sh_index, shdr) in elf_params.shdrs.iter().enumerate() { + let sh_name = elf_params.shdr_strtab.get(shdr.sh_name as usize)?; + if sh_name.eq(".parainstructions") { + serial_println!("Applying {}", sh_name); + + //let (a, b) = mem_map.split_at_mut(usize::from(ModMemType::Text)); + + let (data, _) = get_section_buf(mem_map, sh_index).unwrap(); + let mut d = avec![[{mem::align_of::()}] | 0u8; data.len()]; + d.copy_from_slice(data); + + let text = mem_map.get(usize::from(ModMemType::Text)).unwrap(); + serial_println!("SAMPLE: :{:#x} {:#x} {:#x} {:#x}", d[0], d[1], d[2], d[3],); + apply_paravirt(mem_map, &d, &text.buf); + } + } + + Ok(()) +} + +// I need a way to find a set of bytes and copy i + +fn get_mod_mem_byte_slice(mem_map: &Vec, va: VirtAddr, len: usize) -> Option<&[u8]> { + for mem in mem_map { + let end = VirtAddr::new(va.as_u64() + len as u64); + let mem_end = VirtAddr::new(mem.vtl0_va.as_u64() + mem.buf.len() as u64); + if va >= mem.vtl0_va && end <= mem_end { + let offset = (va - mem.vtl0_va) as usize; + return Some(&mem.buf[offset..offset + len]); + } + } + None +} +fn apply_paravirt(mem_map: &Vec, pv_bytes: &[u8], _text_bytes: &[u8]) { + if pv_bytes.len() % mem::size_of::() != 0 { + return; + } + let pv_count = pv_bytes.len() / mem::size_of::(); + let para: *const ParavirtPatchSite = pv_bytes.as_ptr().cast::(); + if !para.is_aligned() { + return; + } + + for pv_index in 0..pv_count { + let mut pv_addr; + let pv_type; + let pv_len; + + unsafe { + let pv = para.offset(pv_index as isize); + serial_print!( + " Apply paravirt: {:p}:{:#?} type:{} len:{}", + pv, + (*pv).instr, + (*pv).typ, + (*pv).len + ); + pv_addr = (*pv).instr; + pv_type = (*pv).typ; + pv_len = (*pv).len; + } + let Some(insn) = + get_mod_mem_byte_slice(&mem_map, VirtAddr::from_ptr(pv_addr), pv_len as usize) + else { + continue; + }; + serial_println!(" insn:{:?}", insn); + } +} + +const SHF_RO_AFTER_INIT: u32 = 0x00200000; +const SHN_LIVEPATCH: u16 = 0xff20; + +static MOD_MEM_MASK: [ModMemMask; 9] = [ + ModMemMask { + mem_type: ModMemType::Text, + allow_mask: (ElfAbi::SHF_EXECINSTR | ElfAbi::SHF_ALLOC) as u64, + forbid_mask: 0, + init: false, + }, + ModMemMask { + mem_type: ModMemType::RoData, + allow_mask: ElfAbi::SHF_ALLOC as u64, + forbid_mask: ElfAbi::SHF_WRITE as u64, + init: false, + }, + ModMemMask { + mem_type: ModMemType::RoAfterInit, + allow_mask: (SHF_RO_AFTER_INIT | ElfAbi::SHF_ALLOC) as u64, + forbid_mask: 0, + init: false, + }, + ModMemMask { + mem_type: ModMemType::Data, + allow_mask: (ElfAbi::SHF_WRITE | ElfAbi::SHF_ALLOC) as u64, + forbid_mask: 0, + init: false, + }, + ModMemMask { + mem_type: ModMemType::Data, + allow_mask: ElfAbi::SHF_ALLOC as u64, + forbid_mask: 0, + init: false, + }, + ModMemMask { + mem_type: ModMemType::InitText, + allow_mask: (ElfAbi::SHF_EXECINSTR | ElfAbi::SHF_ALLOC) as u64, + forbid_mask: 0, + init: true, + }, + ModMemMask { + mem_type: ModMemType::InitRoData, + allow_mask: ElfAbi::SHF_ALLOC as u64, + forbid_mask: ElfAbi::SHF_WRITE as u64, + init: true, + }, + ModMemMask { + mem_type: ModMemType::InitData, + allow_mask: (ElfAbi::SHF_WRITE | ElfAbi::SHF_ALLOC) as u64, + forbid_mask: 0, + init: true, + }, + ModMemMask { + mem_type: ModMemType::InitData, + allow_mask: ElfAbi::SHF_ALLOC as u64, + forbid_mask: 0, + init: true, + }, +]; + +static MOD_SECTION_SKIP: [&str; 3] = [".modinfo", "__versions", ".data..percpu"]; + +static MOD_SECTION_RO_AFTER_INIT: [&str; 2] = ["__jump_table", ".data..ro_after_init"]; diff --git a/litebox_platform_lvbs/src/mshv/mod.rs b/litebox_platform_lvbs/src/mshv/mod.rs index cd55059ff..cc1f705b0 100644 --- a/litebox_platform_lvbs/src/mshv/mod.rs +++ b/litebox_platform_lvbs/src/mshv/mod.rs @@ -4,6 +4,7 @@ mod heki; pub mod hvcall; mod hvcall_mm; mod hvcall_vp; +mod kmod; mod mem_integrity; pub(crate) mod vsm; mod vsm_intercept; diff --git a/litebox_platform_lvbs/src/mshv/vsm.rs b/litebox_platform_lvbs/src/mshv/vsm.rs index 45c59dcc2..ff4c5758a 100644 --- a/litebox_platform_lvbs/src/mshv/vsm.rs +++ b/litebox_platform_lvbs/src/mshv/vsm.rs @@ -27,6 +27,7 @@ use crate::{ hvcall::HypervCallError, hvcall_mm::hv_modify_vtl_protection_mask, hvcall_vp::{hvcall_get_vp_vtl0_registers, hvcall_set_vp_registers, init_vtl_aps}, + kmod::valid_elf, mem_integrity::{ validate_kernel_module_against_elf, validate_text_patch, verify_kernel_module_signature, verify_kernel_pe_signature, @@ -556,6 +557,12 @@ pub fn mshv_vsm_validate_guest_module(pa: u64, nranges: u64, _flags: u64) -> Res return Err(Errno::EINVAL); } + if let Err(e) = valid_elf(&original_elf_data, &module_in_memory, &module_memory_metadata) { + serial_println!("VSM: kmod: Elf was not valid: {e}"); + } else { + serial_println!("VSM: kmod: Elf was valid"); + }; + // pre-computed patch data for a module if !patch_info_for_module.is_empty() { let mut patch_info_buf = vec![0u8; patch_info_for_module.len()]; @@ -569,6 +576,44 @@ pub fn mshv_vsm_validate_guest_module(pa: u64, nranges: u64, _flags: u64) -> Res .map_err(|_| Errno::EINVAL)?; } + // symbol data for module + if !module_in_memory.symbols.is_empty() { + serial_println!("Found symbols: size:{}", module_in_memory.symbols.len()); + let mut rodata_buf = avec![[{ core::mem::align_of::() }] | 0u8; module_in_memory.rodata.len()]; + module_in_memory + .rodata + .read_bytes(module_in_memory.rodata.start().unwrap(), &mut rodata_buf) + .map_err(|_| Errno::EINVAL)?; + + module_memory_metadata.symbols.build_from_container( + module_in_memory.symbols.range.start, + module_in_memory.symbols.range.end, + &module_in_memory.rodata, + &rodata_buf, + )?; + } + + if !module_in_memory.gpl_symbols.is_empty() { + serial_println!("Found symbols: size:{}", module_in_memory.gpl_symbols.len()); + let mut rodata_buf = avec![[{ core::mem::align_of::() }] | 0u8; module_in_memory.rodata.len()]; + module_in_memory + .rodata + .read_bytes(module_in_memory.rodata.start().unwrap(), &mut rodata_buf) + .map_err(|_| Errno::EINVAL)?; + + module_memory_metadata.symbols.build_from_container( + module_in_memory.gpl_symbols.range.start, + module_in_memory.gpl_symbols.range.end, + &module_in_memory.rodata, + &rodata_buf, + )?; + } + + //read one symbol, one gpl symbol as sample + if let Some(s) = module_memory_metadata.symbols.list_one() { + serial_println!("Module symbol test: {}", s); + } + // once a module is verified and validated, change the permission of its memory ranges based on their types for mod_mem_range in &module_memory_metadata { protect_physical_memory_range( @@ -597,6 +642,16 @@ fn prepare_data_for_module_validation( ) -> Result<(), Errno> { for heki_page in heki_pages { for heki_range in heki_page { + let va = heki_range.va; + let pa = heki_range.pa; + let epa = heki_range.epa; + serial_println!( + "mod:range: type:{:?} va:{:#x} pa:{:#x} epa:{:#x}", + heki_range.mod_mem_type(), + va, + pa, + epa + ); match heki_range.mod_mem_type() { ModMemType::Unknown => { serial_println!("VSM: Invalid module memory type"); @@ -614,10 +669,10 @@ fn prepare_data_for_module_validation( } _ => { // if input memory range's type is neither `Unknown` nor `ElfBuffer`, its addresses must be page-aligned - if !heki_range.is_aligned(Size4KiB::SIZE) { - serial_println!("VSM: input address must be page-aligned"); - return Err(Errno::EINVAL); - } + //if !heki_range.is_aligned(Size4KiB::SIZE) { + // serial_println!("VSM: input address must be page-aligned"); + // return Err(Errno::EINVAL); + //} module_in_memory .write_bytes_from_heki_range(heki_range) @@ -1174,6 +1229,18 @@ impl Vtl0KernelInfo { }) .or(None) } + + pub fn find_symbol(&self, sym_name: &str) -> Option { + match self.gpl_symbols.find(sym_name) { + None => match self.symbols.find(sym_name) { + None => { + self.module_memory_metadata.find_symbol(sym_name) + } + Some(value) => Some(value), + }, + Some(value) => Some(value), + } + } } /// Data structure for maintaining the memory ranges of each VTL0 kernel module and their types @@ -1183,8 +1250,10 @@ pub struct ModuleMemoryMetadataMap { } pub struct ModuleMemoryMetadata { - ranges: Vec, + pub ranges: Vec, // TODO: FEMI: Make this priv again and put in heki_mod or make this struct HekiMod patch_targets: Vec, + symbols: SymbolTable, + gpl_symbols: SymbolTable, } impl ModuleMemoryMetadata { @@ -1192,6 +1261,8 @@ impl ModuleMemoryMetadata { Self { ranges: Vec::new(), patch_targets: Vec::new(), + symbols: SymbolTable::new(), + gpl_symbols: SymbolTable::new(), } } @@ -1213,6 +1284,17 @@ impl ModuleMemoryMetadata { self.ranges.push(mem_range); } + pub(crate) fn get_mem_type_va(&self, mem_type: ModMemType) -> Option { + self.ranges.iter().find_map(|range| { + if mem_type == range.mod_mem_type { + //TODO: FEMI: We take the first range as truth. Fix with new heki_walk/etc + Some(range.virt_addr) + } else { + None + } + }) + } + #[inline] pub(crate) fn insert_patch_target(&mut self, patch_target: PhysAddr) { self.patch_targets.push(patch_target); @@ -1224,6 +1306,18 @@ impl ModuleMemoryMetadata { pub(crate) fn get_patch_targets(&self) -> &Vec { &self.patch_targets } + + pub(crate) fn find_symbol(&self, sym_name: &str) -> Option { + match self.gpl_symbols.find(sym_name) { + None => match self.symbols.find(sym_name) { + None => { + None + } + Some(value) => Some(value), + }, + Some(value) => Some(value), + } + } } impl Default for ModuleMemoryMetadata { @@ -1327,6 +1421,17 @@ impl ModuleMemoryMetadataMap { None } } + + pub fn find_symbol(&self, sym_name: &str) -> Option { + let map = self.inner.lock(); + for (_, data) in map.iter() { + if let Some(value) = data.find_symbol(sym_name) { + serial_println!("Found symbol {} in module", sym_name); + return Some(value); + } + } + None + } } impl Default for ModuleMemoryMetadataMap { @@ -1393,17 +1498,25 @@ fn protect_physical_memory_range( /// Data structure for maintaining the memory content of a kernel module by its sections. Currently, it only maintains /// certain sections like `.text` and `.init.text` which are needed for module validation. pub struct ModuleMemory { - text: MemoryContainer, + pub text: MemoryContainer, + data: MemoryContainer, + rodata: MemoryContainer, init_text: MemoryContainer, init_rodata: MemoryContainer, + symbols: MemoryContainer, + gpl_symbols: MemoryContainer, } impl ModuleMemory { pub fn new() -> Self { Self { text: MemoryContainer::new(), + data: MemoryContainer::new(), + rodata: MemoryContainer::new(), init_text: MemoryContainer::new(), init_rodata: MemoryContainer::new(), + symbols: MemoryContainer::new(), + gpl_symbols: MemoryContainer::new(), } } @@ -1446,6 +1559,8 @@ impl ModuleMemory { ) -> Result<(), MemoryContainerError> { match mod_mem_type { ModMemType::Text => self.text.write_vtl0_phys_bytes(addr, phys_start, phys_end), + ModMemType::Data => self.data.write_vtl0_phys_bytes(addr, phys_start, phys_end), + ModMemType::RoData => self.rodata.write_vtl0_phys_bytes(addr, phys_start, phys_end), ModMemType::InitText => self .init_text .write_vtl0_phys_bytes(addr, phys_start, phys_end), @@ -1454,10 +1569,14 @@ impl ModuleMemory { .write_vtl0_phys_bytes(addr, phys_start, phys_end), ModMemType::ElfBuffer | ModMemType::Patch - | ModMemType::Data - | ModMemType::RoData | ModMemType::RoAfterInit | ModMemType::InitData => Ok(()), // we don't validate other memory types for now + ModMemType::Syms => self + .symbols + .write_vtl0_phys_bytes(addr, phys_start, phys_end), + ModMemType::GplSyms => self + .gpl_symbols + .write_vtl0_phys_bytes(addr, phys_start, phys_end), ModMemType::Unknown => Err(MemoryContainerError::InvalidType), } } @@ -1915,19 +2034,23 @@ impl Symbol { start: VirtAddr, bytes: &[u8], ) -> Result<(String, Self), Errno> { + //serial_println!("from_bytes: start:{:#x} size:{}", start, kinfo_start); let kinfo_bytes = &bytes[kinfo_start..]; let ksym = HekiKernelSymbol::from_bytes(kinfo_bytes)?; + //serial_println!("from_bytes: ksym {:#x} {:#x} ", ksym.name_offset, ksym.value_offset); let value_addr = start + mem::offset_of!(HekiKernelSymbol, value_offset) as u64; let value = value_addr .as_u64() .wrapping_add_signed(i64::from(ksym.value_offset)); + //serial_println!("from_bytes: value:{:#}", value); let name_offset = kinfo_start + mem::offset_of!(HekiKernelSymbol, name_offset) + usize::try_from(ksym.name_offset).map_err(|_| Errno::EINVAL)?; if name_offset >= bytes.len() { + serial_println!("name_offset:{} bytes length {}", name_offset, bytes.len()); return Err(Errno::EINVAL); } let name_len = bytes[name_offset..] @@ -1970,6 +2093,13 @@ impl SymbolTable { mem: &MemoryContainer, buf: &[u8], ) -> Result { + serial_println!( + "build_from_container: start:{:#x} end:{:#x} range:{:#x}..{:#x}", + start, + end, + mem.range.start, + mem.range.end + ); if start < mem.range.start || end > mem.range.end { serial_println!("VSM: Symbol table data not found"); return Err(Errno::EINVAL); @@ -1986,12 +2116,33 @@ impl SymbolTable { let mut inner = self.inner.write(); inner.reserve(ksym_count); + let mut rate_limit = 0; for _ in 0..ksym_count { let (name, sym) = Symbol::from_bytes(kinfo_offset, kinfo_addr, buf).unwrap(); + if rate_limit < 20 { + crate::serial_println!(" sym:{} value:{}", name, sym._value); + rate_limit += 1; + } inner.insert(name, sym); kinfo_offset += HekiKernelSymbol::KSYM_LEN; kinfo_addr += HekiKernelSymbol::KSYM_LEN as u64; } Ok(0) } + + pub fn find(&self, sym_name: &str) -> Option { + let inner = self.inner.write(); + let Some(sym) = inner.get(sym_name) else { + return None; + }; + Some(sym._value) + } + + pub fn list_one(&self) -> Option { + let inner = self.inner.write(); + for (s, _v) in inner.iter() { + return Some(s.clone()); + } + None + } }