use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// A type-keyed container that holds at most one value per concrete type.
///
/// Each value is stored behind an [`Arc`] and keyed by its [`TypeId`]. Cloning a
/// `Context` clones the map of `Arc`s, so clones share the stored values, but
/// later inserts or removals on one clone do not affect the other. Every stored
/// value must be `Send + Sync + 'static`, so a `Context` is itself `Send + Sync`.
#[derive(Clone, Debug)]
pub struct Context {
    /// Type-erased values keyed by the `TypeId` of their concrete type.
    ///
    /// Invariant: the entry stored under `TypeId::of::<E>()` always holds an `E`.
    type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Default for Context {
    /// Returns an empty context; equivalent to [`Context::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Context {
    /// Creates an empty context.
    #[must_use]
    pub fn new() -> Self {
        Self {
            type_map: HashMap::new(),
        }
    }

    /// Stores `entity` as the value for type `E`, replacing any existing `E`.
    ///
    /// Returns the displaced value, or `None` if no `E` was stored before.
    pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .insert(TypeId::of::<E>(), Arc::new(entity))
            .map(Self::downcast_entry::<E>)
    }

    /// Stores `entity` as the value for type `E`, silently discarding any
    /// existing `E`.
    ///
    /// Returns `&mut Self` so insertions can be chained fluently.
    pub fn insert<E>(&mut self, entity: E) -> &mut Self
    where
        E: Send + Sync + 'static,
    {
        self.type_map.insert(TypeId::of::<E>(), Arc::new(entity));
        self
    }

    /// Removes the value stored for type `E` and returns it, or returns `None`
    /// if no `E` is stored.
    pub fn remove<E>(&mut self) -> Option<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .remove(&TypeId::of::<E>())
            .map(Self::downcast_entry::<E>)
    }

    /// Returns a shared reference to the value stored for type `E`, if any.
    ///
    /// Only `&E` is handed out. To mutate a stored value, store a type with
    /// interior mutability such as `Mutex<E>`.
    #[must_use]
    pub fn get<E>(&self) -> Option<&E>
    where
        E: Send + Sync + 'static,
    {
        self.type_map
            .get(&TypeId::of::<E>())
            .and_then(|item| item.downcast_ref())
    }

    /// Returns the number of distinct types currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.type_map.len()
    }

    /// Returns `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.type_map.is_empty()
    }

    /// Recovers the concrete `Arc<E>` from a type-erased map entry.
    ///
    /// # Panics
    ///
    /// Panics if `entry` does not hold an `E`. This is unreachable as long as
    /// entries are only ever stored under the `TypeId` of their own type.
    fn downcast_entry<E>(entry: Arc<dyn Any + Send + Sync>) -> Arc<E>
    where
        E: Send + Sync + 'static,
    {
        entry
            .downcast()
            .expect("entry stored under TypeId::of::<E>() must hold an E")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    #[test]
    fn insert_get_string() {
        let mut context = Context::new();
        // The first value of a given type displaces nothing.
        assert_eq!(None, context.insert_or_replace("pollo".to_string()));
        assert_eq!(Some(&"pollo".to_string()), context.get());
    }

    #[test]
    fn insert_get_custom_structs() {
        #[derive(Debug, PartialEq, Eq)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq)]
        struct S2 {}
        let mut context = Context::new();
        context.insert_or_replace(S1 {});
        context.insert_or_replace(S2 {});
        // Re-inserting a type hands back the previously stored value.
        assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
        assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));
        assert_eq!(Some(&S1 {}), context.get());
        assert_eq!(Some(&S2 {}), context.get());
    }

    #[test]
    fn insert_fluent_syntax() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {}
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S2 {}
        let mut context = Context::new();
        context
            .insert("static str")
            .insert("a String".to_string())
            .insert(S1::default())
            // Inserting S1 a second time replaces the entry rather than adding one.
            .insert(S1::default())
            .insert(S2::default());
        assert_eq!(4, context.len());
        assert_eq!(Some(&"static str"), context.get());
        assert_eq!(Some(&"a String".to_string()), context.get());
        assert_eq!(Some(&S1::default()), context.get());
        assert_eq!(Some(&S2::default()), context.get());
    }

    #[test]
    fn remove_returns_value_and_shrinks() {
        let mut context = Context::default();
        assert!(context.is_empty());
        context.insert(5u16);
        assert_eq!(1, context.len());
        assert!(!context.is_empty());
        assert_eq!(Some(Arc::new(5u16)), context.remove::<u16>());
        assert_eq!(None, context.get::<u16>());
        // Removing an absent type is a no-op that yields None.
        assert_eq!(None, context.remove::<u16>());
        assert!(context.is_empty());
    }

    #[test]
    fn clone_has_independent_map() {
        let mut original = Context::new();
        original.insert(1i32);
        let copy = original.clone();
        // Replacing in one clone must not be visible through the other.
        original.insert_or_replace(2i32);
        assert_eq!(Some(&2i32), original.get());
        assert_eq!(Some(&1i32), copy.get());
    }

    // Compile-time check: only instantiable when `T: Send + Sync`.
    fn require_send_sync<T: Send + Sync>(_: &T) {}

    #[test]
    fn test_require_send_sync() {
        require_send_sync(&Context::new());
    }

    #[test]
    fn mutability() {
        #[derive(Debug, PartialEq, Eq, Default)]
        struct S1 {
            num: u8,
        }
        let mut context = Context::new();
        // `get` only hands out `&T`, so mutation goes through a Mutex.
        context.insert_or_replace(Mutex::new(S1::default()));
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
        context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;
        assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
        let displaced = context
            .insert_or_replace(Mutex::new(S1::default()))
            .unwrap();
        assert_eq!(42, displaced.lock().unwrap().num);
        assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
        context.insert_or_replace(Mutex::new(33u32));
        *context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
        assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
    }
}