diff --git a/README.md b/README.md index 1086988..23d8a35 100644 --- a/README.md +++ b/README.md @@ -34,4 +34,4 @@ The tutorial has 8 parts (which can be finished in 7 days): * Day 6: Recovery. We will implement WAL and manifest so that the engine can recover after restart. * Day 7: Bloom filter and key compression. They are widely-used optimizations in LSM tree structures. -We have reference solution up to day 3 and tutorial up to day 2 for now. +We have reference solution up to day 4 and tutorial up to day 2 for now. diff --git a/mini-lsm/src/lsm_storage.rs b/mini-lsm/src/lsm_storage.rs index cf9144c..4c134df 100644 --- a/mini-lsm/src/lsm_storage.rs +++ b/mini-lsm/src/lsm_storage.rs @@ -16,16 +16,19 @@ use crate::table::{SsTable, SsTableBuilder, SsTableIterator}; #[derive(Clone)] pub struct LsmStorageInner { - /// MemTables, from oldest to earliest. - memtables: Vec>, - /// L0 SsTables, from oldest to earliest. + /// The current memtable. + memtable: Arc, + /// Immutable memTables, from earliest to latest. + imm_memtables: Vec>, + /// L0 SsTables, from earliest to latest. l0_sstables: Vec>, } impl LsmStorageInner { fn create() -> Self { Self { - memtables: vec![Arc::new(MemTable::create())], + memtable: Arc::new(MemTable::create()), + imm_memtables: vec![], l0_sstables: vec![], } } @@ -47,8 +50,17 @@ impl LsmStorage { pub fn get(&self, key: &[u8]) -> Result> { let snapshot = self.inner.load(); - for memtable in snapshot.memtables.iter().rev() { - if let Some(value) = memtable.get(key)? { + // Search on the current memtable. + if let Some(value) = snapshot.memtable.get(key) { + if value.is_empty() { + // found tomestone, return key not exists + return Ok(None); + } + return Ok(Some(value)); + } + // Search on immutable memtables. + for memtable in snapshot.imm_memtables.iter().rev() { + if let Some(value) = memtable.get(key) { if value.is_empty() { // found tomestone, return key not exists return Ok(None); @@ -74,30 +86,67 @@ impl LsmStorage { pub fn put(&self, key: &[u8], value: &[u8]) -> Result<()> { assert!(!value.is_empty(), "value cannot be empty"); assert!(!key.is_empty(), "key cannot be empty"); - let snapshot = self.inner.load(); - snapshot.memtables[0].put(key, value)?; + loop { + let snapshot = self.inner.load(); + if snapshot.memtable.put(key, value) { + break; + } + // waiting for a new memtable to be propagated + } Ok(()) } pub fn delete(&self, key: &[u8]) -> Result<()> { - let snapshot = self.inner.load(); - snapshot.memtables[0].put(key, b"")?; + assert!(!key.is_empty(), "key cannot be empty"); + loop { + let snapshot = self.inner.load(); + if snapshot.memtable.put(key, b"") { + break; + } + // waiting for a new memtable to be propagated + } Ok(()) } pub fn sync(&self) -> Result<()> { let _flush_lock = self.flush_lock.lock(); - let mut snapshot = { - let snapshot = self.inner.load(); - snapshot.as_ref().clone() - }; + + let flush_memtable; + + // Move mutable memtable to immutable memtables. + { + let guard = self.inner.load(); + // Swap the current memtable with a new one. + let mut snapshot = guard.as_ref().clone(); + let memtable = std::mem::replace(&mut snapshot.memtable, Arc::new(MemTable::create())); + flush_memtable = memtable.clone(); + // Add the memtable to the immutable memtables. + snapshot.imm_memtables.push(memtable.clone()); + // Disable the memtable. + memtable.seal(); + // Update the snapshot. + self.inner.store(Arc::new(snapshot)); + } + + // At this point, the old memtable should be disabled for write, and all threads should be + // operating on the new memtable. We can safely flush the old memtable to disk. let mut builder = SsTableBuilder::new(4096); - let memtable = snapshot.memtables.pop().unwrap(); - assert!(snapshot.memtables.is_empty()); - memtable.flush(&mut builder)?; - snapshot.l0_sstables.push(Arc::new(builder.build("")?)); - self.inner.store(Arc::new(snapshot)); + flush_memtable.flush(&mut builder)?; + let sst = Arc::new(builder.build("")?); + + // Add the flushed L0 table to the list. + { + let guard = self.inner.load(); + let mut snapshot = guard.as_ref().clone(); + // Remove the memtable from the immutable memtables. + snapshot.imm_memtables.pop(); + // Add L0 table + snapshot.l0_sstables.push(sst); + // Update the snapshot. + self.inner.store(Arc::new(snapshot)); + } + Ok(()) } @@ -109,8 +158,9 @@ impl LsmStorage { let snapshot = self.inner.load(); let mut memtable_iters = Vec::new(); - memtable_iters.reserve(snapshot.memtables.len()); - for memtable in snapshot.memtables.iter().rev() { + memtable_iters.reserve(snapshot.imm_memtables.len() + 1); + memtable_iters.push(Box::new(snapshot.memtable.scan(lower, upper)?)); + for memtable in snapshot.imm_memtables.iter().rev() { memtable_iters.push(Box::new(memtable.scan(lower, upper)?)); } let memtable_iter = MergeIterator::create(memtable_iters); diff --git a/mini-lsm/src/mem_table.rs b/mini-lsm/src/mem_table.rs index 17801d8..edb2076 100644 --- a/mini-lsm/src/mem_table.rs +++ b/mini-lsm/src/mem_table.rs @@ -1,4 +1,5 @@ use std::ops::Bound; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use anyhow::Result; @@ -13,6 +14,7 @@ use crate::table::SsTableBuilder; /// A basic mem-table based on crossbeam-skiplist pub struct MemTable { map: Arc>, + sealed: AtomicBool, } pub(crate) fn map_bound(bound: Bound<&[u8]>) -> Bound { @@ -28,20 +30,24 @@ impl MemTable { pub fn create() -> Self { Self { map: Arc::new(SkipMap::new()), + sealed: AtomicBool::new(false), } } /// Get a value by key. - pub fn get(&self, key: &[u8]) -> Result> { - let entry = self.map.get(key).map(|e| e.value().clone()); - Ok(entry) + pub fn get(&self, key: &[u8]) -> Option { + self.map.get(key).map(|e| e.value().clone()) } - /// Put a key-value pair into the mem-table. - pub fn put(&self, key: &[u8], value: &[u8]) -> Result<()> { + /// Put a key-value pair into the mem-table. If the current mem-table is sealed, return false. + pub fn put(&self, key: &[u8], value: &[u8]) -> bool { + use std::sync::atomic::Ordering; + if self.sealed.load(Ordering::Acquire) { + return false; + } self.map .insert(Bytes::copy_from_slice(key), Bytes::copy_from_slice(value)); - Ok(()) + true } /// Get an iterator over a range of keys. @@ -65,6 +71,12 @@ impl MemTable { } Ok(()) } + + /// Disable writes to this memtable. + pub(crate) fn seal(&self) { + use std::sync::atomic::Ordering; + self.sealed.store(true, Ordering::Release); + } } type SkipMapRangeIter<'a> = diff --git a/mini-lsm/src/mem_table/tests.rs b/mini-lsm/src/mem_table/tests.rs index 4b1f7d0..e711a0a 100644 --- a/mini-lsm/src/mem_table/tests.rs +++ b/mini-lsm/src/mem_table/tests.rs @@ -5,34 +5,34 @@ use crate::table::{SsTableBuilder, SsTableIterator}; #[test] fn test_memtable_get() { let memtable = MemTable::create(); - memtable.put(b"key1", b"value1").unwrap(); - memtable.put(b"key2", b"value2").unwrap(); - memtable.put(b"key3", b"value3").unwrap(); - assert_eq!(&memtable.get(b"key1").unwrap().unwrap()[..], b"value1"); - assert_eq!(&memtable.get(b"key2").unwrap().unwrap()[..], b"value2"); - assert_eq!(&memtable.get(b"key3").unwrap().unwrap()[..], b"value3"); + memtable.put(b"key1", b"value1"); + memtable.put(b"key2", b"value2"); + memtable.put(b"key3", b"value3"); + assert_eq!(&memtable.get(b"key1").unwrap()[..], b"value1"); + assert_eq!(&memtable.get(b"key2").unwrap()[..], b"value2"); + assert_eq!(&memtable.get(b"key3").unwrap()[..], b"value3"); } #[test] fn test_memtable_overwrite() { let memtable = MemTable::create(); - memtable.put(b"key1", b"value1").unwrap(); - memtable.put(b"key2", b"value2").unwrap(); - memtable.put(b"key3", b"value3").unwrap(); - memtable.put(b"key1", b"value11").unwrap(); - memtable.put(b"key2", b"value22").unwrap(); - memtable.put(b"key3", b"value33").unwrap(); - assert_eq!(&memtable.get(b"key1").unwrap().unwrap()[..], b"value11"); - assert_eq!(&memtable.get(b"key2").unwrap().unwrap()[..], b"value22"); - assert_eq!(&memtable.get(b"key3").unwrap().unwrap()[..], b"value33"); + memtable.put(b"key1", b"value1"); + memtable.put(b"key2", b"value2"); + memtable.put(b"key3", b"value3"); + memtable.put(b"key1", b"value11"); + memtable.put(b"key2", b"value22"); + memtable.put(b"key3", b"value33"); + assert_eq!(&memtable.get(b"key1").unwrap()[..], b"value11"); + assert_eq!(&memtable.get(b"key2").unwrap()[..], b"value22"); + assert_eq!(&memtable.get(b"key3").unwrap()[..], b"value33"); } #[test] fn test_memtable_flush() { let memtable = MemTable::create(); - memtable.put(b"key1", b"value1").unwrap(); - memtable.put(b"key2", b"value2").unwrap(); - memtable.put(b"key3", b"value3").unwrap(); + memtable.put(b"key1", b"value1"); + memtable.put(b"key2", b"value2"); + memtable.put(b"key3", b"value3"); let mut builder = SsTableBuilder::new(128); memtable.flush(&mut builder).unwrap(); let sst = builder.build("").unwrap(); @@ -53,9 +53,9 @@ fn test_memtable_flush() { fn test_memtable_iter() { use std::ops::Bound; let memtable = MemTable::create(); - memtable.put(b"key1", b"value1").unwrap(); - memtable.put(b"key2", b"value2").unwrap(); - memtable.put(b"key3", b"value3").unwrap(); + memtable.put(b"key1", b"value1"); + memtable.put(b"key2", b"value2"); + memtable.put(b"key3", b"value3"); { let mut iter = memtable.scan(Bound::Unbounded, Bound::Unbounded).unwrap(); diff --git a/mini-lsm/src/tests/day3_tests.rs b/mini-lsm/src/tests/day3_tests.rs index 1189910..b591ca3 100644 --- a/mini-lsm/src/tests/day3_tests.rs +++ b/mini-lsm/src/tests/day3_tests.rs @@ -105,3 +105,81 @@ fn test_storage_scan_memtable_2() { vec![(Bytes::from("2"), Bytes::from("2333"))], ); } + +#[test] +fn test_storage_get_after_sync() { + use crate::lsm_storage::LsmStorage; + + let storage = LsmStorage::open("").unwrap(); + storage.put(b"1", b"233").unwrap(); + storage.put(b"2", b"2333").unwrap(); + storage.sync().unwrap(); + storage.put(b"3", b"23333").unwrap(); + assert_eq!(&storage.get(b"1").unwrap().unwrap()[..], b"233"); + assert_eq!(&storage.get(b"2").unwrap().unwrap()[..], b"2333"); + assert_eq!(&storage.get(b"3").unwrap().unwrap()[..], b"23333"); + storage.delete(b"2").unwrap(); + assert!(storage.get(b"2").unwrap().is_none()); +} + +#[test] +fn test_storage_scan_memtable_1_after_sync() { + use crate::lsm_storage::LsmStorage; + + let storage = LsmStorage::open("").unwrap(); + storage.put(b"1", b"233").unwrap(); + storage.put(b"2", b"2333").unwrap(); + storage.sync().unwrap(); + storage.put(b"3", b"23333").unwrap(); + storage.delete(b"2").unwrap(); + check_iter_result( + storage.scan(Bound::Unbounded, Bound::Unbounded).unwrap(), + vec![ + (Bytes::from("1"), Bytes::from("233")), + (Bytes::from("3"), Bytes::from("23333")), + ], + ); + check_iter_result( + storage + .scan(Bound::Included(b"1"), Bound::Included(b"2")) + .unwrap(), + vec![(Bytes::from("1"), Bytes::from("233"))], + ); + check_iter_result( + storage + .scan(Bound::Excluded(b"1"), Bound::Excluded(b"3")) + .unwrap(), + vec![], + ); +} + +#[test] +fn test_storage_scan_memtable_2_after_sync() { + use crate::lsm_storage::LsmStorage; + + let storage = LsmStorage::open("").unwrap(); + storage.put(b"1", b"233").unwrap(); + storage.put(b"2", b"2333").unwrap(); + storage.put(b"3", b"23333").unwrap(); + storage.sync().unwrap(); + storage.delete(b"1").unwrap(); + check_iter_result( + storage.scan(Bound::Unbounded, Bound::Unbounded).unwrap(), + vec![ + (Bytes::from("2"), Bytes::from("2333")), + (Bytes::from("3"), Bytes::from("23333")), + ], + ); + check_iter_result( + storage + .scan(Bound::Included(b"1"), Bound::Included(b"2")) + .unwrap(), + vec![(Bytes::from("2"), Bytes::from("2333"))], + ); + check_iter_result( + storage + .scan(Bound::Excluded(b"1"), Bound::Excluded(b"3")) + .unwrap(), + vec![(Bytes::from("2"), Bytes::from("2333"))], + ); +}