× Close
Skip to content

Commit 2adc355

Browse files
committed
Fix type/instruction cache races with atomic ops
- Type cache: use `try_to_owned` (CAS-based `safe_inc`) instead of `ManuallyDrop` + clone, to prevent a use-after-free when a concurrent writer drops the cached value between the pointer load and the refcount increment.
- Instruction cache: replace the split 4×u16 read/write (torn read/write under concurrent specialization) with a single `AtomicUsize` load/store via a `pointer_cache` side array.
- Instruction cache reads: use `try_to_owned` with a deoptimize fallback instead of a blind `to_owned` on cached descriptor pointers.
- Call `modified()` BEFORE the attribute modification in `set_attr` and `set___type_params__`, so cached descriptor pointers remain alive when `type_cache_clear_version` drops the cache entries.
1 parent 223cf12 commit 2adc355

File tree

3 files changed

+147
-123
lines changed

3 files changed

+147
-123
lines changed

crates/compiler-core/src/bytecode.rs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use core::{
1212
cell::UnsafeCell,
1313
hash, mem,
1414
ops::Deref,
15-
sync::atomic::{AtomicU8, AtomicU16, Ordering},
15+
sync::atomic::{AtomicU8, AtomicU16, AtomicUsize, Ordering},
1616
};
1717
use itertools::Itertools;
1818
use malachite_bigint::BigInt;
@@ -411,6 +411,10 @@ impl TryFrom<&[u8]> for CodeUnit {
411411
pub struct CodeUnits {
412412
units: UnsafeCell<Box<[CodeUnit]>>,
413413
adaptive_counters: Box<[AtomicU16]>,
414+
/// Pointer-sized cache entries for descriptor pointers.
415+
/// Single atomic load/store prevents torn reads when multiple threads
416+
/// specialize the same instruction concurrently.
417+
pointer_cache: Box<[AtomicUsize]>,
414418
}
415419

416420
// SAFETY: All cache operations use atomic read/write instructions.
@@ -432,9 +436,15 @@ impl Clone for CodeUnits {
432436
.iter()
433437
.map(|c| AtomicU16::new(c.load(Ordering::Relaxed)))
434438
.collect();
439+
let pointer_cache = self
440+
.pointer_cache
441+
.iter()
442+
.map(|c| AtomicUsize::new(c.load(Ordering::Relaxed)))
443+
.collect();
435444
Self {
436445
units: UnsafeCell::new(units),
437446
adaptive_counters,
447+
pointer_cache,
438448
}
439449
}
440450
}
@@ -472,13 +482,19 @@ impl<const N: usize> From<[CodeUnit; N]> for CodeUnits {
472482
impl From<Vec<CodeUnit>> for CodeUnits {
473483
fn from(value: Vec<CodeUnit>) -> Self {
474484
let units = value.into_boxed_slice();
475-
let adaptive_counters = (0..units.len())
485+
let len = units.len();
486+
let adaptive_counters = (0..len)
476487
.map(|_| AtomicU16::new(0))
477488
.collect::<Vec<_>>()
478489
.into_boxed_slice();
490+
let pointer_cache = (0..len)
491+
.map(|_| AtomicUsize::new(0))
492+
.collect::<Vec<_>>()
493+
.into_boxed_slice();
479494
Self {
480495
units: UnsafeCell::new(units),
481496
adaptive_counters,
497+
pointer_cache,
482498
}
483499
}
484500
}
@@ -600,25 +616,22 @@ impl CodeUnits {
600616
lo | (hi << 16)
601617
}
602618

603-
/// Write a u64 value across four consecutive CACHE code units starting at `index`.
619+
/// Store a pointer-sized value atomically in the pointer cache at `index`.
620+
///
621+
/// Uses a single `AtomicUsize` store to prevent torn writes when
622+
/// multiple threads specialize the same instruction concurrently.
604623
///
605624
/// # Safety
606-
/// Same requirements as `write_cache_u16`.
607-
pub unsafe fn write_cache_u64(&self, index: usize, value: u64) {
608-
unsafe {
609-
self.write_cache_u32(index, value as u32);
610-
self.write_cache_u32(index + 2, (value >> 32) as u32);
611-
}
625+
/// - `index` must be in bounds.
626+
pub unsafe fn write_cache_ptr(&self, index: usize, value: usize) {
627+
self.pointer_cache[index].store(value, Ordering::Relaxed);
612628
}
613629

614-
/// Read a u64 value from four consecutive CACHE code units starting at `index`.
630+
/// Load a pointer-sized value atomically from the pointer cache at `index`.
615631
///
616-
/// # Panics
617-
/// Panics if `index + 3` is out of bounds.
618-
pub fn read_cache_u64(&self, index: usize) -> u64 {
619-
let lo = self.read_cache_u32(index) as u64;
620-
let hi = self.read_cache_u32(index + 2) as u64;
621-
lo | (hi << 32)
632+
/// Uses a single `AtomicUsize` load to prevent torn reads.
633+
pub fn read_cache_ptr(&self, index: usize) -> usize {
634+
self.pointer_cache[index].load(Ordering::Relaxed)
622635
}
623636

624637
/// Read adaptive counter bits for instruction at `index`.

crates/vm/src/builtins/type.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,11 @@ impl PyType {
701701
}
702702

703703
pub fn set_attr(&self, attr_name: &'static PyStrInterned, value: PyObjectRef) {
704-
self.attributes.write().insert(attr_name, value);
704+
// Invalidate caches BEFORE modifying attributes so that cached
705+
// descriptor pointers are still alive when type_cache_clear_version
706+
// drops the cache's strong references.
705707
self.modified();
708+
self.attributes.write().insert(attr_name, value);
706709
}
707710

708711
/// Internal get_attr implementation for fast lookup on a class.
@@ -735,23 +738,20 @@ impl PyType {
735738
{
736739
let ptr = entry.value.load(Ordering::Acquire);
737740
if !ptr.is_null() {
738-
// SAFETY: The value pointer was stored via PyObjectRef::into_raw
739-
// and is valid as long as the version hasn't changed. We create
740-
// a temporary reference (ManuallyDrop prevents decrement), clone
741-
// it to get our own strong reference, then re-check the version
742-
// to confirm the entry wasn't invalidated during our read.
743-
let cloned = unsafe {
744-
let tmp = core::mem::ManuallyDrop::new(PyObjectRef::from_raw(
745-
NonNull::new_unchecked(ptr),
746-
));
747-
(*tmp).clone()
748-
};
749-
// SeqLock validation: if version changed, discard our clone
750-
let v2 = entry.version.load(Ordering::Acquire);
751-
if v2 == v1 {
752-
return Some(cloned);
741+
// Use try_to_owned (CAS-based safe_inc) to prevent a
742+
// use-after-free race: between loading the pointer and
743+
// incrementing the refcount, a concurrent writer may
744+
// invalidate the entry and drop the old value. safe_inc
745+
// fails atomically if the refcount has already reached 0.
746+
let obj: &PyObject = unsafe { &*ptr };
747+
if let Some(cloned) = obj.try_to_owned() {
748+
// SeqLock validation: if version changed, discard
749+
let v2 = entry.version.load(Ordering::Acquire);
750+
if v2 == v1 {
751+
return Some(cloned);
752+
}
753+
drop(cloned);
753754
}
754-
drop(cloned);
755755
}
756756
}
757757
}
@@ -1498,8 +1498,8 @@ impl PyType {
14981498
PySetterValue::Assign(ref val) => {
14991499
let key = identifier!(vm, __type_params__);
15001500
self.check_set_special_type_attr(key, vm)?;
1501-
self.attributes.write().insert(key, val.clone().into());
15021501
self.modified();
1502+
self.attributes.write().insert(key, val.clone().into());
15031503
}
15041504
PySetterValue::Delete => {
15051505
// For delete, we still need to check if the type is immutable
@@ -1510,8 +1510,8 @@ impl PyType {
15101510
)));
15111511
}
15121512
let key = identifier!(vm, __type_params__);
1513-
self.attributes.write().shift_remove(&key);
15141513
self.modified();
1514+
self.attributes.write().shift_remove(&key);
15151515
}
15161516
}
15171517
Ok(())

0 commit comments

Comments
 (0)
× Close