本文发表于入职啦(公众号: ruzhila) 大家可以访问入职啦学习更多的编程实战。
项目地址
代码已经开源, mget_rust 👏 Star
代码运行效果:
如何实现多线程下载
HTTP通过在HEAD添加Range头,可以实现分片下载,这样就可以实现多线程下载。
- 在下载之前先通过HEAD请求获取文件的大小,然后根据文件的大小,分配线程的下载区间,然后每个线程下载对应的区间,就可以实现多线程下载
直接上代码
代码解析
如何初始化项目
我们通过cargo初始化项目,并且我们会引入一些依赖,比如reqwest用于HTTP请求,clap用于命令行参数解析。
cargo new mget_rs
cd mget_rs
cargo add reqwest
cargo add clap
同时要修改Cargo.toml文件,默认的reqwest是基于tokio的,我们需要修改为:
[dependencies]
clap = { version = "4.4.17", features = ["derive"] }
reqwest = { version = "0.12.5", features = ["blocking"] }
命令行参数解析
clap是非常好用的命令行参数解析库,我们可以通过clap解析命令行参数。
#[derive(Parser, Debug)]
#[command(version)]
struct Cli {
#[clap(long, short, default_value = "2")]
threads: usize,
#[clap(long, short)]
output: Option<String>,
#[clap(long, short, default_value = "false")]
verbose: bool,
url: String,
}
这样我们就可以实现命令行参数的解析:
- -t/--thread 用于指定线程数
- -o/--output 用于指定输出文件
- -v/--verbose 用于打印详细信息
实现文件名的处理
如果url上没有文件名,我们需要通过url解析出文件名,默认是index.html:
let file_name = match output {
Some(name) => name.to_string(),
None => {
let url = Url::parse(url).map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
url.path_segments()
.and_then(|segments| segments.last())
.and_then(|name| if name.is_empty() { None } else { Some(name) })
.unwrap_or("index.html")
.to_string()
}
};
并且如果要创建文件,不能直接覆盖,我们需要判断文件是否存在,我们可以创建新的文件名:
// try rename the file to avoid conflict
let mut file_name = file_name;
let mut index = 1;
while std::fs::metadata(&file_name).is_ok() {
let parts: Vec<&str> = file_name.rsplitn(2, '.').collect();
if parts.len() == 2 {
file_name = format!("{}.{}.{}", parts[1], index, parts[0]);
} else {
file_name = format!("{}.{}", file_name, index);
}
index += 1;
}
多线程下载
这部分会比较麻烦,我们需要实现至少1个下载线程去执行下载任务,然后主线程等待所有的下载线程完成。
这时候,我们就需要引入一个mpsc的channel,用于线程之间的通信。 mpsc 的意思就是多个生产者,单个消费者,主线程通过等子线程的状态,来判断是否所有的线程都已经完成或者更新任务:
// 我们先定义好线程之间的数据格式
enum TaskResult {
// 下载中
Downloading(usize, u64, Box<[u8]>), // 线程id,下载的大小,下载的数据
// 下载失败
Failed(usize, Error), // 线程id,错误信息
// 下载完成
Done(usize), // 线程id
}
多线程之间的通信
mpsc需要用TaskResult进行通信,我们看看多线程之间的通信:
//主线程启动多个下载线程
let (tx, rx) = std::sync::mpsc::channel::<TaskResult>();
for idx in 0..threads {
// 实现分片分配
let pos = idx as u64 * file_size / threads as u64;
let length = if idx == threads - 1 {
file_size - pos
} else {
file_size / threads as u64
};
let url = url.to_string();
let tx = tx.clone();
spawn(move || download_part(tx.clone(), url, idx, pos, length));
}
子线程执行下载任务
实现下载逻辑,并且将结果通过tx发送给主线程
fn download_part(tx: Sender<TaskResult>, url: String, idx: usize, pos: u64, length: u64) -> u64 {
}
主线程等待子线程的状态
主线程等待子线程的状态,如果所有的线程都完成,就退出循环。
let mut done_count = 0;
loop {
match rx.recv() {
Ok(TaskResult::Downloading(_idx, pos, data)) => {
downloaded += data.len() as u64;
outfile.seek(std::io::SeekFrom::Start(pos))?;
outfile.write_all(&data)?;
}
Ok(TaskResult::Failed(idx, e)) => {
println!("Thread {} failed: {}", idx, e);
return Err(e);
}
Ok(TaskResult::Done(_idx)) => {
done_count += 1;
if done_count == threads {
break;
}
}
}
}
总结
通过这个例子,我们可以学习到多线程的数据同步、文件操作等知识,同时我们也可以学习到命令行参数解析、HTTP请求等知识。
Rust以严谨的类型系统和内存安全著称,同时Rust的多线程编程也非常强大,通过Rust的多线程编程,我们可以实现高效的并发编程。
整个代码简洁明了,方便大家学习多线程的数据同步、文件操作等知识。
如果大家对后端编程有兴趣,可以关注入职啦,我们会定期更新后端编程的实战教程。
入职啦学习交流群