Feat: add file and vector index cache search

This commit is contained in:
gongzhengyang 2025-09-23 10:03:21 +08:00 committed by 龚正阳
parent 7576daa6cc
commit b16fe21506
14 changed files with 338 additions and 310 deletions

View File

@ -1,2 +1,3 @@
[workspace]
members = ["example", "xdb"]
resolver = "2"
members = ["example", "ip2region"]

View File

@ -11,7 +11,8 @@
# 缓存方式说明
由于基于文件的查询以及缓存`VectorIndex`索引在并发较高(比如每秒上百并发)的情况下,每次查询都会从磁盘加载`ip2region.xdb`文件进入内存,由此会产生很高的磁盘`IO`以及极大的内存占用,所以决定做一次减法,不对这两种缓存进行开发,只提供缓存整个`xdb`文件的方式,以此实现最小的并发查询内存开销以及极限`CPU`性能压榨
由于基于文件的查询以及缓存`VectorIndex`索引在并发较高(比如每秒上百并发)的情况下,查询会从磁盘上的`ip2region.xdb`按需进行`IO`读取,由于
占用内存较低,
# 使用方式

View File

@ -1,16 +1,15 @@
[package]
name = "rust-example"
default-run = "rust-example"
version = "0.1.0"
edition = "2021"
rust-version = "1.66.0"
description = "the rust binding for ip2region"
version = "0.2.0"
edition = "2024"
rust-version = "1.89.0"
description = "Rust binding example for ip2region"
license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
xdb = { path = "../xdb" }
clap = { version = "4.0" }
tracing = "0.1"
tracing-subscriber = "0.3.14"
ip2region = { path = "../ip2region" }
clap = { version = "4.5", features = ["derive", "env"] }
tracing-subscriber = "0.3"

View File

@ -1,27 +1,34 @@
use clap::{Arg, ArgMatches, Command};
use clap::{Parser, Subcommand, ValueEnum};
pub fn get_matches() -> ArgMatches {
let db_arg = Arg::new("db")
.long("db")
.help("the xdb filepath, you can set this field like \
../data/ip2region.xdb,if you dont set,\
if will detect xdb file on ../data/ip2region.xdb, ../../data/ip2region.xdb, ../../../data/ip2region.xdb if exists");
Command::new("ip2region")
.version("0.1")
.about("ip2region bin program")
.long_about("you can set --db in command to specific the xdb filepath, default run query")
.subcommand(Command::new("query").about("query test").arg(&db_arg))
.subcommand(
Command::new("bench")
.about("bench test")
.arg(
Arg::new("src")
.long("src")
.help("set this to specific source bench file")
.required(true),
)
.arg(&db_arg),
)
.get_matches()
/// Rust binding example for ip2region
///
/// `cargo run -- --xdb=../../../data/ip2region_v4.xdb bench ../../../data/ip.test.txt`
///
/// `cargo run -- --xdb=../../../data/ip2region_v4.xdb query`
///
#[derive(Parser)]
pub struct Command {
/// xdb filepath, e.g. `../../../data/ip2region_v4.xdb`
#[arg(long, env = "XDB")]
pub xdb: String,
#[arg(long, value_enum, default_value_t = CmdCachePolicy::FullMemory)]
pub cache_policy: CmdCachePolicy,
#[clap(subcommand)]
pub action: Action,
}
#[derive(Subcommand)]
pub enum Action {
/// Bench the ip search and output performance info
Bench { check_file: String},
/// Interactive input and output, querying one IP and get result at a time
Query,
}
#[derive(Debug, PartialEq, ValueEnum, Clone, Copy, Default)]
pub enum CmdCachePolicy {
#[default]
FullMemory,
NoCache,
VectorIndex,
}

View File

@ -7,26 +7,16 @@ use std::net::Ipv4Addr;
use std::str::FromStr;
use std::time::Instant;
use clap::ArgMatches;
use xdb::{search_by_ip, searcher_init};
use clap::Parser;
use ip2region::{Searcher, CachePolicy};
use crate::cmd::{Action, CmdCachePolicy, Command};
mod cmd;
/// set rust log level, if you don`t want print log, you can skip this
fn log_init() {
let rust_log_key = "RUST_LOG";
std::env::var(rust_log_key).unwrap_or_else(|_| {
std::env::set_var(rust_log_key, "INFO");
std::env::var(rust_log_key).unwrap()
});
tracing_subscriber::fmt::init();
}
fn bench_test(src_filepath: &str) {
fn bench(searcher: &Searcher, check_filepath: &str) {
let now = Instant::now();
let mut count = 0;
let mut file = File::open(src_filepath).unwrap();
let mut file = File::open(check_filepath).unwrap();
let mut contents = String::new();
file.read_to_string(&mut contents).unwrap();
@ -53,7 +43,7 @@ fn bench_test(src_filepath: &str) {
((mid_ip as u64 + end_ip as u64) >> 1) as u32,
end_ip,
] {
let result = search_by_ip(ip).unwrap();
let result = searcher.search(ip).unwrap();
assert_eq!(result.as_str(), ip_test_line[2]);
count += 1;
}
@ -67,7 +57,7 @@ fn bench_test(src_filepath: &str) {
)
}
fn query_test() {
fn query(searcher: &Searcher) {
println!("ip2region xdb searcher test program, type `quit` or `Ctrl + c` to exit");
loop {
print!("ip2region>> ");
@ -79,32 +69,25 @@ fn query_test() {
}
let line = line.trim();
let now = Instant::now();
let result = search_by_ip(line);
let result = searcher.search(line);
let cost = now.elapsed();
println!("region: {result:?}, took: {cost:?}",);
}
}
fn matches_for_searcher(matches: &ArgMatches) {
if let Some(xdb_filepath) = matches.get_one::<String>("db") {
searcher_init(Some(xdb_filepath.to_owned()))
} else {
searcher_init(None);
}
}
fn main() {
log_init();
let matches = cmd::get_matches();
if let Some(sub_matches) = matches.subcommand_matches("bench") {
matches_for_searcher(sub_matches);
let src_filepath = sub_matches.get_one::<String>("src").unwrap();
tracing_subscriber::fmt::init();
bench_test(src_filepath);
}
if let Some(sub_matches) = matches.subcommand_matches("query") {
matches_for_searcher(sub_matches);
query_test()
let cmd = Command::parse();
let cache_policy = match cmd.cache_policy {
CmdCachePolicy::FullMemory => CachePolicy::FullMemory,
CmdCachePolicy::VectorIndex => CachePolicy::VectorIndex,
CmdCachePolicy::NoCache => CachePolicy::NoCache
};
let searcher = Searcher::new(cmd.xdb, cache_policy);
match cmd.action {
Action::Bench{ check_file} => bench(&searcher, &check_file),
Action::Query => query(&searcher)
}
}

View File

@ -1,21 +1,20 @@
[package]
name = "xdb"
version = "0.1.0"
edition = "2021"
rust-version = "1.66.0"
description = "the rust binding for ip2region"
name = "ip2region"
version = "0.2.0"
edition = "2024"
rust-version = "1.89.0"
description = "The rust binding for ip2region"
license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
once_cell = "1.16"
tracing = "0.1"
tracing-subscriber = "0.3.14"
thiserror = "2"
[dev-dependencies]
criterion = "0.4"
rand = "0.8"
criterion = "0.7"
rand = "0.9"
[[bench]]
name = "search"

View File

@ -0,0 +1,31 @@
use criterion::{Criterion, criterion_group, criterion_main};
use rand;
use ip2region::{CachePolicy, Searcher};
const XDB_FILEPATH: &'static str = "../../../data/ip2region_v4.xdb";
macro_rules! bench_search {
($name:ident, $cache_policy:expr) => {
fn $name(c: &mut Criterion) {
c.bench_function(stringify!($name), |b| {
let searcher = Searcher::new(XDB_FILEPATH.to_owned(), $cache_policy);
b.iter(|| {
searcher.search(rand::random::<u32>()).unwrap();
})
});
}
};
}
bench_search!(no_memory_bench, CachePolicy::NoCache);
bench_search!(vector_index_cache_bench, CachePolicy::VectorIndex);
bench_search!(full_memory_cache_bench, CachePolicy::FullMemory);
criterion_group!(
benches,
no_memory_bench,
vector_index_cache_bench,
full_memory_cache_bench,
);
criterion_main!(benches);

View File

@ -0,0 +1,16 @@
#[derive(Debug, thiserror::Error)]
pub enum Ip2RegionError {
#[error("Io error: {0}")]
IoError(#[from] std::io::Error),
#[error("From UTF-8 error: {0}")]
Utf8Error(#[from] std::string::FromUtf8Error),
#[error("Parse invalid IP address")]
ParseIpaddress(#[from] std::num::ParseIntError),
#[error("No matched Ipaddress")]
NoMatchedIP,
}
pub type Result<T> = std::result::Result<T, Ip2RegionError>;

View File

@ -1,19 +1,20 @@
use std::error::Error;
use std::net::Ipv4Addr;
use std::str::FromStr;
use crate::error::Result;
pub trait ToUIntIP {
fn to_u32_ip(&self) -> Result<u32, Box<dyn Error>>;
fn to_u32_ip(&self) -> Result<u32>;
}
impl ToUIntIP for u32 {
fn to_u32_ip(&self) -> Result<u32, Box<dyn Error>> {
fn to_u32_ip(&self) -> Result<u32> {
Ok(self.to_owned())
}
}
impl ToUIntIP for &str {
fn to_u32_ip(&self) -> Result<u32, Box<dyn Error>> {
fn to_u32_ip(&self) -> Result<u32> {
if let Ok(ip_addr) = Ipv4Addr::from_str(self) {
return Ok(u32::from(ip_addr));
}
@ -22,7 +23,7 @@ impl ToUIntIP for &str {
}
impl ToUIntIP for Ipv4Addr {
fn to_u32_ip(&self) -> Result<u32, Box<dyn Error>> {
fn to_u32_ip(&self) -> Result<u32> {
Ok(u32::from(*self))
}
}

View File

@ -0,0 +1,6 @@
mod error;
mod ip_value;
mod searcher;
pub use self::ip_value::ToUIntIP;
pub use self::searcher::{CachePolicy, Searcher};

View File

@ -0,0 +1,208 @@
use std::borrow::Cow;
use std::fmt::Display;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::sync::OnceLock;
use tracing::{debug, trace, warn};
use crate::ToUIntIP;
use crate::error::{Ip2RegionError, Result};
const HEADER_INFO_LENGTH: usize = 256;
const VECTOR_INDEX_LENGTH: usize = 256 * 256 * 8;
const VECTOR_INDEX_COLS: usize = 256;
const VECTOR_INDEX_SIZE: usize = 8;
const SEGMENT_INDEX_SIZE: usize = 14;
static VECTOR_INDEX_CACHE: OnceLock<Vec<u8>> = OnceLock::new();
static FULL_CACHE: OnceLock<Vec<u8>> = OnceLock::new();
pub struct Searcher {
pub filepath: String,
pub cache_policy: CachePolicy,
}
#[derive(PartialEq, Debug)]
pub enum CachePolicy {
NoCache,
VectorIndex,
FullMemory,
}
impl Searcher {
pub fn new(filepath: String, cache_policy: CachePolicy) -> Self {
Self {
filepath,
cache_policy,
}
}
pub fn search<T>(&self, ip: T) -> Result<String>
where
T: ToUIntIP + Display,
{
let ip = ip.to_u32_ip()?;
let il0 = ((ip >> 24) & 0xFF) as usize;
let il1 = ((ip >> 16) & 0xFF) as usize;
let start_point = VECTOR_INDEX_SIZE * (il0 * VECTOR_INDEX_COLS + il1);
let vector_index = self.vector_index()?;
let start_ptr = get_block_by_size(&vector_index, start_point, 4);
let end_ptr = get_block_by_size(&vector_index, start_point + 4, 4);
let mut left: usize = 0;
let mut right: usize = (end_ptr - start_ptr) / SEGMENT_INDEX_SIZE;
while left <= right {
let mid = (left + right) >> 1;
let offset = start_ptr + mid * SEGMENT_INDEX_SIZE;
let buffer_ip_value = self.read_buf(offset, SEGMENT_INDEX_SIZE)?;
let start_ip = get_block_by_size(&buffer_ip_value, 0, 4);
if ip < (start_ip as u32) {
right = mid - 1;
} else if ip > (get_block_by_size(&buffer_ip_value, 4, 4) as u32) {
left = mid + 1;
} else {
let data_length = get_block_by_size(&buffer_ip_value, 8, 2);
let data_offset = get_block_by_size(&buffer_ip_value, 10, 4);
let result = String::from_utf8(self.read_buf(data_offset, data_length)?.to_vec())?;
return Ok(result);
}
}
Err(Ip2RegionError::NoMatchedIP)
}
pub fn vector_index(&self) -> Result<Cow<'_, [u8]>> {
if self.cache_policy.eq(&CachePolicy::NoCache) {
return self.read_buf(HEADER_INFO_LENGTH, VECTOR_INDEX_LENGTH);
}
match VECTOR_INDEX_CACHE.get() {
None => {
debug!("Load vector index cache");
let data = self
.read_buf(HEADER_INFO_LENGTH, VECTOR_INDEX_LENGTH)?
.to_vec();
let _ = VECTOR_INDEX_CACHE
.set(data)
.inspect_err(|_| warn!("Vector index cache already initialized"));
// Safety: VECTOR_INDEX_CACHE checked and set for empty before
let cache = VECTOR_INDEX_CACHE.get().unwrap();
Ok(Cow::Borrowed(cache))
}
Some(cache) => Ok(Cow::Borrowed(cache)),
}
}
pub fn read_buf(&self, offset: usize, size: usize) -> Result<Cow<'_, [u8]>> {
trace!(offset, size = size, "Read buffer");
if self.cache_policy.ne(&CachePolicy::FullMemory) {
debug!(filepath=?self.filepath, offset=offset, size=size, "Read buf without cache");
let mut file = File::open(&self.filepath)?;
file.seek(SeekFrom::Start(offset as u64))?;
let mut buf = vec![0u8; size];
file.take(size as u64).read_exact(&mut buf)?;
return Ok(Cow::from(buf));
}
match FULL_CACHE.get() {
None => {
debug!(filepath=?self.filepath, "Load full cache");
let mut file = File::open(&self.filepath)?;
let mut buf = Vec::new();
file.read_to_end(&mut buf)?;
let _ = FULL_CACHE
.set(buf)
.inspect_err(|_| warn!("Full cache already initialized"));
// Safety: FULL_CACHE checked and set for empty before
let cache = FULL_CACHE.get().unwrap();
Ok(Cow::from(&cache[offset..offset + size]))
}
Some(cache) => {
let data = Cow::from(&cache[offset..offset + size]);
Ok(data)
}
}
}
}
#[inline]
pub fn get_block_by_size(bytes: &[u8], offset: usize, length: usize) -> usize {
let mut result: usize = 0;
for (index, value) in bytes[offset..offset + length].iter().enumerate() {
result += usize::from(*value) << (index << 3);
}
result
}
#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::net::Ipv4Addr;
use std::str::FromStr;
use super::*;
const XDB_PATH: &str = "../../../data/ip2region_v4.xdb";
const CHECK_PATH: &str = "../../../data/ipv4_source.txt";
fn multi_type_ip(searcher: &Searcher) {
searcher.search("2.0.0.0").unwrap();
searcher.search("32").unwrap();
searcher.search(4294408949).unwrap();
searcher
.search(Ipv4Addr::from_str("1.1.1.1").unwrap())
.unwrap();
}
///test all types find correct
#[test]
fn test_multi_type_ip() {
for cache_policy in [
CachePolicy::NoCache,
CachePolicy::FullMemory,
CachePolicy::VectorIndex,
] {
multi_type_ip(&Searcher::new(XDB_PATH.to_owned(), cache_policy));
}
}
fn match_ip_correct(searcher: &Searcher) {
let file = File::open(CHECK_PATH).unwrap();
let reader = BufReader::new(file);
for line in reader.lines().take(100) {
let line = line.unwrap();
if !line.contains("|") {
continue;
}
let ip_test_line = line.splitn(3, "|").collect::<Vec<&str>>();
let start_ip = Ipv4Addr::from_str(ip_test_line[0]).unwrap();
let end_ip = Ipv4Addr::from_str(ip_test_line[1]).unwrap();
for _ in 0..10 {
let value = rand::random_range(u32::from(start_ip)..u32::from(end_ip) + 1);
let result = searcher.search(value).unwrap();
assert_eq!(result.as_str(), ip_test_line[2])
}
}
}
#[test]
fn test_match_ip_correct() {
for cache_policy in [
CachePolicy::NoCache,
CachePolicy::FullMemory,
CachePolicy::VectorIndex,
] {
match_ip_correct(&Searcher::new(XDB_PATH.to_owned(), cache_policy));
}
}
}

View File

@ -1,52 +0,0 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand;
use xdb::searcher::{
get_block_by_size, get_full_cache, get_vector_index_cache, search_by_ip, searcher_init,
};
fn search_by_ip_bench(c: &mut Criterion) {
c.bench_function("search_by_ip_bench", |b| {
searcher_init(None);
b.iter(|| {
search_by_ip(rand::random::<u32>()).unwrap();
})
});
}
fn get_block_by_size_bench(c: &mut Criterion) {
c.bench_function("get_block_by_size_bench", |b| {
b.iter(|| {
black_box(get_block_by_size(
get_full_cache(),
rand::random::<u16>() as usize,
4,
));
})
});
}
fn get_full_cache_bench(c: &mut Criterion) {
c.bench_function("get_full_cache_bench", |b| {
b.iter(|| {
black_box(get_full_cache());
})
});
}
fn get_vec_index_cache_bench(c: &mut Criterion) {
c.bench_function("get_vec_index_cache_bench", |b| {
b.iter(|| {
black_box(get_vector_index_cache());
})
});
}
criterion_group!(
benches,
search_by_ip_bench,
get_block_by_size_bench,
get_full_cache_bench,
get_vec_index_cache_bench,
);
criterion_main!(benches);

View File

@ -1,4 +0,0 @@
mod ip_value;
pub use self::ip_value::ToUIntIP;
pub mod searcher;
pub use searcher::{search_by_ip, searcher_init};

View File

@ -1,168 +0,0 @@
use std::error::Error;
use std::fmt::Display;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use once_cell::sync::OnceCell;
use crate::ToUIntIP;
const HEADER_INFO_LENGTH: usize = 256;
const VECTOR_INDEX_COLS: usize = 256;
const VECTOR_INDEX_SIZE: usize = 8;
const SEGMENT_INDEX_SIZE: usize = 14;
const VECTOR_INDEX_LENGTH: usize = 512 * 1024;
const XDB_FILEPATH_ENV: &str = "XDB_FILEPATH";
static CACHE: OnceCell<Vec<u8>> = OnceCell::new();
/// check https://mp.weixin.qq.com/s/ndjzu0BgaeBmDOCw5aqHUg for details
pub fn search_by_ip<T>(ip: T) -> Result<String, Box<dyn Error>>
where
T: ToUIntIP + Display,
{
let ip = ip.to_u32_ip()?;
let il0 = ((ip >> 24) & 0xFF) as usize;
let il1 = ((ip >> 16) & 0xFF) as usize;
let idx = VECTOR_INDEX_SIZE * (il0 * VECTOR_INDEX_COLS + il1);
let start_point = idx;
let vector_cache = get_vector_index_cache();
let start_ptr = get_block_by_size(vector_cache, start_point, 4);
let end_ptr = get_block_by_size(vector_cache, start_point + 4, 4);
let mut left: usize = 0;
let mut right: usize = (end_ptr - start_ptr) / SEGMENT_INDEX_SIZE;
while left <= right {
let mid = (left + right) >> 1;
let offset = start_ptr + mid * SEGMENT_INDEX_SIZE;
let buffer_ip_value = &get_full_cache()[offset..offset + SEGMENT_INDEX_SIZE];
let start_ip = get_block_by_size(buffer_ip_value, 0, 4);
if ip < (start_ip as u32) {
right = mid - 1;
} else if ip > (get_block_by_size(buffer_ip_value, 4, 4) as u32) {
left = mid + 1;
} else {
let data_length = get_block_by_size(buffer_ip_value, 8, 2);
let data_offset = get_block_by_size(buffer_ip_value, 10, 4);
let result = String::from_utf8(
get_full_cache()[data_offset..(data_offset + data_length)].to_vec(),
);
return Ok(result?);
}
}
Err("not matched".into())
}
/// it will check ../data/ip2region.xdb, ../../data/ip2region.xdb, ../../../data/ip2region.xdb
fn default_detect_xdb_file() -> Result<String, Box<dyn Error>> {
let prefix = "../".to_owned();
for recurse in 1..4 {
let filepath = prefix.repeat(recurse) + "data/ip2region.xdb";
if Path::new(filepath.as_str()).exists() {
return Ok(filepath);
}
}
Err("default filepath not find the xdb file, so you must set xdb_filepath".into())
}
#[inline]
pub fn get_block_by_size(bytes: &[u8], offset: usize, length: usize) -> usize {
let mut result: usize = 0;
for (index, value) in bytes[offset..offset + length].iter().enumerate() {
result += usize::from(*value) << (index << 3);
}
result
}
pub fn searcher_init(xdb_filepath: Option<String>) {
let xdb_filepath = xdb_filepath.unwrap_or_else(|| default_detect_xdb_file().unwrap());
std::env::set_var(XDB_FILEPATH_ENV, xdb_filepath);
CACHE.get_or_init(load_file);
}
pub fn get_vector_index_cache() -> &'static [u8] {
let full_cache: &'static Vec<u8> = get_full_cache();
&full_cache[HEADER_INFO_LENGTH..(HEADER_INFO_LENGTH + VECTOR_INDEX_LENGTH)]
}
fn load_file() -> Vec<u8> {
let xdb_filepath =
std::env::var("XDB_FILEPATH").unwrap_or_else(|_| default_detect_xdb_file().unwrap());
tracing::debug!("load xdb searcher file at {} ", xdb_filepath);
let mut f = File::open(xdb_filepath).expect("file open error");
let mut buffer = Vec::new();
f.read_to_end(&mut buffer).expect("load file error");
buffer
}
pub fn get_full_cache() -> &'static Vec<u8> {
CACHE.get_or_init(load_file)
}
#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::Read;
use std::net::Ipv4Addr;
use std::str::FromStr;
use std::thread;
use super::*;
///test all types find correct
#[test]
fn test_multi_type_ip() {
searcher_init(None);
search_by_ip("2.0.0.0").unwrap();
search_by_ip("32").unwrap();
search_by_ip(4294408949).unwrap();
search_by_ip(Ipv4Addr::from_str("1.1.1.1").unwrap()).unwrap();
}
#[test]
fn test_match_all_ip_correct() {
searcher_init(None);
let mut file = File::open("../../../data/ip.test.txt").unwrap();
let mut contents = String::new();
file.read_to_string(&mut contents).unwrap();
for line in contents.split("\n") {
if !line.contains("|") {
continue;
}
let ip_test_line = line.splitn(3, "|").collect::<Vec<&str>>();
let start_ip = Ipv4Addr::from_str(ip_test_line[0]).unwrap();
let end_ip = Ipv4Addr::from_str(ip_test_line[1]).unwrap();
for value in u32::from(start_ip)..u32::from(end_ip) + 1 {
let result = search_by_ip(value).unwrap();
assert_eq!(result.as_str(), ip_test_line[2])
}
}
}
#[test]
fn test_multi_thread_only_load_xdb_once() {
searcher_init(None);
let handle = thread::spawn(|| {
let result = search_by_ip("2.2.2.2").unwrap();
println!("ip search in spawn: {result}");
});
let r = search_by_ip("1.1.1.1").unwrap();
println!("ip search in main thread: {r}");
handle.join().unwrap();
}
#[test]
fn test_multi_searcher_init() {
for _ in 0..5 {
thread::spawn(|| {
searcher_init(None);
});
}
searcher_init(None);
searcher_init(Some(String::from("test")));
search_by_ip(123).unwrap();
}
}