项目地址

如何实现多线程下载

直接上代码

代码解析

如何初始化项目

命令行参数解析

实现文件名的处理

多线程下载

多线程之间的通信

子线程执行下载任务

主线程等待子线程的状态

总结

用rust实现多线程的HTTP下载器,不需要tokio

本文发表于入职啦(公众号: ruzhila) 大家可以访问入职啦学习更多的编程实战。

项目地址

代码已经开源, mget_rust 👏 Star

代码运行效果:

mget_rs

如何实现多线程下载

HTTP通过在HEAD添加Range头,可以实现分片下载,这样就可以实现多线程下载。

  • 在下载之前先通过HEAD请求获取文件的大小,然后根据文件的大小,分配线程的下载区间,然后每个线程下载对应的区间,就可以实现多线程下载

直接上代码

code

代码解析

如何初始化项目

我们通过cargo初始化项目,并且我们会引入一些依赖,比如reqwest用于HTTP请求,clap用于命令行参数解析。

cargo new mget_rs
cd mget_rs
cargo add reqwest
cargo add clap

同时要修改Cargo.toml文件,默认的reqwest是基于tokio的,我们需要修改为:

[dependencies]
clap = { version = "4.4.17", features = ["derive"] }
reqwest = { version = "0.12.5", features = ["blocking"] }

命令行参数解析

clap是非常好用的命令行参数解析库,我们可以通过clap解析命令行参数。

#[derive(Parser, Debug)]
#[command(version)]
struct Cli {
    #[clap(long, short, default_value = "2")]
    threads: usize,

    #[clap(long, short)]
    output: Option<String>,

    #[clap(long, short, default_value = "false")]
    verbose: bool,

    url: String,
}

这样我们就可以实现命令行参数的解析:

  • -t/--thread 用于指定线程数
  • -o/--output 用于指定输出文件
  • -v/--verbose 用于打印详细信息

实现文件名的处理

如果url上没有文件名,我们需要通过url解析出文件名,默认是index.html:

    let file_name = match output {
        Some(name) => name.to_string(),
        None => {
            let url = Url::parse(url).map_err(|e| Error::new(ErrorKind::InvalidInput, e))?;
            url.path_segments()
                .and_then(|segments| segments.last())
                .and_then(|name| if name.is_empty() { None } else { Some(name) })
                .unwrap_or("index.html")
                .to_string()
        }
    };

并且如果要创建文件,不能直接覆盖,我们需要判断文件是否存在,我们可以创建新的文件名:

    // try rename the file to avoid conflict
    let mut file_name = file_name;
    let mut index = 1;
    while std::fs::metadata(&file_name).is_ok() {
        let parts: Vec<&str> = file_name.rsplitn(2, '.').collect();
        if parts.len() == 2 {
            file_name = format!("{}.{}.{}", parts[1], index, parts[0]);
        } else {
            file_name = format!("{}.{}", file_name, index);
        }
        index += 1;
    }

多线程下载

这部分会比较麻烦,我们需要实现至少1个下载线程去执行下载任务,然后主线程等待所有的下载线程完成。

这时候,我们就需要引入一个mpsc的channel,用于线程之间的通信。 mpsc 的意思就是多个生产者,单个消费者,主线程通过等子线程的状态,来判断是否所有的线程都已经完成或者更新任务:

// 我们先定义好线程之间的数据格式
enum TaskResult {
    // 下载中
    Downloading(usize, u64, Box<[u8]>),  // 线程id,下载的大小,下载的数据
    // 下载失败
    Failed(usize, Error), // 线程id,错误信息
    // 下载完成
    Done(usize), // 线程id 
}

多线程之间的通信

mpsc需要用TaskResult进行通信,我们看看多线程之间的通信:

//主线程启动多个下载线程
let (tx, rx) = std::sync::mpsc::channel::<TaskResult>();
for idx in 0..threads {
   // 实现分片分配
   let pos = idx as u64 * file_size / threads as u64;
   let length = if idx == threads - 1 {
      file_size - pos
   } else {
      file_size / threads as u64
   };
   let url = url.to_string();
   let tx = tx.clone();
   spawn(move || download_part(tx.clone(), url, idx, pos, length));
}

子线程执行下载任务

实现下载逻辑,并且将结果通过tx发送给主线程

fn download_part(tx: Sender<TaskResult>, url: String, idx: usize, pos: u64, length: u64) -> u64 {
   
}

主线程等待子线程的状态

主线程等待子线程的状态,如果所有的线程都完成,就退出循环。

let mut done_count = 0;
loop {
        match rx.recv() {
            Ok(TaskResult::Downloading(_idx, pos, data)) => {
                downloaded += data.len() as u64;
                outfile.seek(std::io::SeekFrom::Start(pos))?;
                outfile.write_all(&data)?;
            }
            Ok(TaskResult::Failed(idx, e)) => {
                println!("Thread {} failed: {}", idx, e);
                return Err(e);
            }
            Ok(TaskResult::Done(_idx)) => {
                done_count += 1;
                if done_count == threads {
                    break;
                }
            }
        }
    }

总结

通过这个例子,我们可以学习到多线程的数据同步、文件操作等知识,同时我们也可以学习到命令行参数解析、HTTP请求等知识。

Rust以严谨的类型系统和内存安全著称,同时Rust的多线程编程也非常强大,通过Rust的多线程编程,我们可以实现高效的并发编程。

整个代码简洁明了,方便大家学习多线程的数据同步、文件操作等知识。

如果大家对后端编程有兴趣,可以关注入职啦,我们会定期更新后端编程的实战教程。

入职啦学习交流群

入群学习

友情链接:

Copyright© 2024 杭州园中葵科技有限公司 版权所有