Signed-off-by: kjuulh <contact@kjuulh.io>
This commit is contained in:
1
crates/scel/.gitignore
vendored
Normal file
1
crates/scel/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/target
|
16
crates/scel/Cargo.toml
Normal file
16
crates/scel/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "scel"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.22", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
anyhow = { version = "1.0.66" }
|
||||
dotenv = { version = "*" }
|
||||
|
||||
scel_api = { path = "../scel_api" }
|
||||
scel_core = { path = "../scel_core" }
|
28
crates/scel/src/main.rs
Normal file
28
crates/scel/src/main.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use dotenv::dotenv;
|
||||
use scel_core::App;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{EnvFilter, FmtSubscriber};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
dotenv().ok();
|
||||
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_env_filter(
|
||||
EnvFilter::default()
|
||||
.add_directive("tower_http=debug".parse().unwrap())
|
||||
.add_directive("scel_api=info".parse().unwrap())
|
||||
.add_directive("scel=info".parse().unwrap()),
|
||||
)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber)?;
|
||||
|
||||
info!("Starting scel");
|
||||
|
||||
let app = Arc::new(App::new());
|
||||
|
||||
scel_api::Server::new(app.clone()).start().await
|
||||
}
|
33
crates/scel_api/Cargo.toml
Normal file
33
crates/scel_api/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "scel_api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.5.17" }
|
||||
axum-extra = { version = "0.3.7", features = ["spa"] }
|
||||
futures = "0.3.28"
|
||||
tower-http = { version = "0.3.4", features = ["cors", "trace"] }
|
||||
async-graphql = { version = "4.0.16", features = [
|
||||
'tracing',
|
||||
'opentelemetry',
|
||||
"log",
|
||||
] }
|
||||
async-graphql-axum = { version = "4.0.16" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0.89"
|
||||
tokio = { version = "1.22", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3" }
|
||||
anyhow = { version = "1.0.66" }
|
||||
oauth2 = { version = "*" }
|
||||
async-session = { version = "*" }
|
||||
reqwest = { version = "*", default-features = false, features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
] }
|
||||
hyper = { version = "*" }
|
||||
|
||||
scel_core = { path = "../scel_core" }
|
99
crates/scel_api/src/auth/mod.rs
Normal file
99
crates/scel_api/src/auth/mod.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use std::env;
|
||||
|
||||
use async_session::{MemoryStore, Session, SessionStore};
|
||||
use axum::{
|
||||
extract::Query,
|
||||
http::HeaderMap,
|
||||
response::{IntoResponse, Redirect},
|
||||
Extension,
|
||||
};
|
||||
use oauth2::{
|
||||
basic::BasicClient, reqwest::async_http_client, AuthUrl, AuthorizationCode, ClientId,
|
||||
ClientSecret, CsrfToken, RedirectUrl, TokenResponse, TokenUrl,
|
||||
};
|
||||
use reqwest::header::SET_COOKIE;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{User, COOKIE_NAME};
|
||||
|
||||
pub fn oauth_client() -> BasicClient {
|
||||
let client_id = env::var("GITEA_CLIENT_ID").expect("Missing GITEA_CLIENT_ID");
|
||||
let client_secret = env::var("GITEA_CLIENT_SECRET").expect("Missing GITEA_CLIENT_SECRET");
|
||||
let redirect_url = env::var("GITEA_REDIRECT_URL")
|
||||
.unwrap_or_else(|_| "http://127.0.0.1:3000/auth/authorized".to_string());
|
||||
|
||||
let auth_url =
|
||||
env::var("GITEA_AUTH_URL").unwrap_or_else(|_| "https://git.front.kjuulh.io".to_string());
|
||||
|
||||
let token_url =
|
||||
env::var("GITEA_TOKEN_URL").unwrap_or_else(|_| "https://git.front.kjuulh.io".to_string());
|
||||
|
||||
BasicClient::new(
|
||||
ClientId::new(client_id),
|
||||
Some(ClientSecret::new(client_secret)),
|
||||
AuthUrl::new(auth_url).expect("AuthUrl was invalid"),
|
||||
Some(TokenUrl::new(token_url).expect("Token url was invalid")),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("RedirectUrl was invalid"))
|
||||
}
|
||||
|
||||
pub async fn gitea(Extension(client): Extension<BasicClient>) -> impl IntoResponse {
|
||||
let (auth_url, _crsf_token) = client.authorize_url(CsrfToken::new_random).url();
|
||||
|
||||
Redirect::to(&auth_url.to_string())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthRequest {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
pub async fn authorized(
|
||||
Query(query): Query<AuthRequest>,
|
||||
Extension(store): Extension<MemoryStore>,
|
||||
Extension(oauth_client): Extension<BasicClient>,
|
||||
) -> impl IntoResponse {
|
||||
let token = oauth_client
|
||||
.exchange_code(AuthorizationCode::new(query.code.clone()))
|
||||
.request_async(async_http_client)
|
||||
.await
|
||||
.expect("failed to get http client");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let user_data_json = client
|
||||
.get(get_gitea_user_data_url())
|
||||
.bearer_auth(token.access_token().secret())
|
||||
.send()
|
||||
.await
|
||||
.expect("Request did not succeed");
|
||||
// .text()
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
let user_data: User = user_data_json
|
||||
.json::<User>()
|
||||
.await
|
||||
.expect("could not parse user");
|
||||
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.insert("user", &user_data)
|
||||
.expect("could not insert user data");
|
||||
|
||||
let cookie = store
|
||||
.store_session(session)
|
||||
.await
|
||||
.expect("could not insert session")
|
||||
.expect("session was not valid");
|
||||
|
||||
let cookie = format!("{}={}; SameSite=Lax; Path=/", COOKIE_NAME, cookie);
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(SET_COOKIE, cookie.parse().expect("Cookie is not valid"));
|
||||
(headers, Redirect::to("/"))
|
||||
}
|
||||
|
||||
fn get_gitea_user_data_url() -> String {
|
||||
env::var("GITEA_USER_INFO_URL").expect("Missing GITEA_USER_INFO_URL")
|
||||
}
|
4
crates/scel_api/src/graphql/mod.rs
Normal file
4
crates/scel_api/src/graphql/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod mutation;
|
||||
pub mod query;
|
||||
pub mod schema;
|
||||
pub mod subscription;
|
38
crates/scel_api/src/graphql/mutation.rs
Normal file
38
crates/scel_api/src/graphql/mutation.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_graphql::{Context, Object, Result, SimpleObject, ID};
|
||||
use scel_core::{services::Download, App};
|
||||
|
||||
pub struct MutationRoot;
|
||||
|
||||
#[derive(SimpleObject)]
|
||||
struct RequestDownloadResponse {
|
||||
id: ID,
|
||||
}
|
||||
|
||||
#[Object]
|
||||
impl MutationRoot {
|
||||
async fn request_download(
|
||||
&self,
|
||||
ctx: &Context<'_>,
|
||||
download_link: String,
|
||||
) -> Result<RequestDownloadResponse> {
|
||||
let app = ctx.data_unchecked::<Arc<App>>();
|
||||
|
||||
let download = app
|
||||
.download_service
|
||||
.clone()
|
||||
.add_download(Download {
|
||||
id: None,
|
||||
link: download_link,
|
||||
progress: None,
|
||||
file_name: None,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Ok(RequestDownloadResponse {
|
||||
id: download.id.unwrap().into(),
|
||||
})
|
||||
}
|
||||
}
|
32
crates/scel_api/src/graphql/query.rs
Normal file
32
crates/scel_api/src/graphql/query.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_graphql::{Context, Object, Result, SimpleObject, ID};
|
||||
use scel_core::App;
|
||||
|
||||
#[derive(SimpleObject, Clone)]
|
||||
pub struct Download {
|
||||
pub id: ID,
|
||||
pub link: String,
|
||||
pub progress: Option<u32>,
|
||||
pub file_name: Option<String>,
|
||||
}
|
||||
|
||||
pub struct QueryRoot;
|
||||
|
||||
#[Object]
|
||||
impl QueryRoot {
|
||||
async fn get_download(&self, ctx: &Context<'_>, id: ID) -> Result<Option<Download>> {
|
||||
let app = ctx.data_unchecked::<Arc<App>>();
|
||||
|
||||
match app.download_service.get_download(id.to_string()).await {
|
||||
Ok(Some(d)) => Ok(Some(Download {
|
||||
id: ID::from(d.id.expect("ID could not be found")),
|
||||
progress: None,
|
||||
link: d.link,
|
||||
file_name: None,
|
||||
})),
|
||||
Ok(None) => Ok(None),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
5
crates/scel_api/src/graphql/schema.rs
Normal file
5
crates/scel_api/src/graphql/schema.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
use async_graphql::Schema;
|
||||
|
||||
use super::{mutation::MutationRoot, query::QueryRoot, subscription::SubscriptionRoot};
|
||||
|
||||
pub type ScelSchema = Schema<QueryRoot, MutationRoot, SubscriptionRoot>;
|
49
crates/scel_api/src/graphql/subscription.rs
Normal file
49
crates/scel_api/src/graphql/subscription.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_graphql::{
|
||||
async_stream::stream, futures_util::Stream, Context, Object, Subscription, ID,
|
||||
};
|
||||
use scel_core::App;
|
||||
|
||||
use super::query::Download;
|
||||
|
||||
pub struct SubscriptionRoot;
|
||||
|
||||
struct DownloadChanged {
|
||||
download: Download,
|
||||
}
|
||||
|
||||
#[Object]
|
||||
impl DownloadChanged {
|
||||
async fn download(&self) -> Download {
|
||||
self.download.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[Subscription]
|
||||
impl SubscriptionRoot {
|
||||
async fn get_download(&self, ctx: &Context<'_>, id: ID) -> impl Stream<Item = DownloadChanged> {
|
||||
let app = ctx.data_unchecked::<Arc<App>>();
|
||||
|
||||
let mut stream = app
|
||||
.download_service
|
||||
.subscribe_download(id.to_string())
|
||||
.await;
|
||||
|
||||
stream! {
|
||||
while stream.changed().await.is_ok() {
|
||||
let next_download = (*stream.borrow()).clone();
|
||||
let id = ID::from(next_download.id.unwrap());
|
||||
|
||||
yield DownloadChanged {
|
||||
download: Download {
|
||||
id: id,
|
||||
link: next_download.link,
|
||||
file_name: next_download.file_name,
|
||||
progress: next_download.progress,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
173
crates/scel_api/src/lib.rs
Normal file
173
crates/scel_api/src/lib.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
mod auth;
|
||||
mod graphql;
|
||||
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
|
||||
use async_graphql::{
|
||||
extensions::{Logger, Tracing},
|
||||
http::{playground_source, GraphQLPlaygroundConfig},
|
||||
Request, Response, Schema,
|
||||
};
|
||||
use async_graphql_axum::GraphQLSubscription;
|
||||
use async_session::{async_trait, MemoryStore, SessionStore};
|
||||
use auth::{authorized, gitea};
|
||||
use axum::{
|
||||
extract::{rejection::TypedHeaderRejectionReason, FromRequest, RequestParts},
|
||||
headers,
|
||||
http::{header, Method},
|
||||
response::{Html, IntoResponse, Redirect},
|
||||
routing::{self, get_service},
|
||||
Extension, Json, Router, TypedHeader,
|
||||
};
|
||||
use graphql::{
|
||||
mutation::MutationRoot, query::QueryRoot, schema::ScelSchema, subscription::SubscriptionRoot,
|
||||
};
|
||||
use reqwest::StatusCode;
|
||||
use scel_core::App;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_http::{
|
||||
cors::CorsLayer,
|
||||
services::ServeDir,
|
||||
trace::{DefaultMakeSpan, TraceLayer},
|
||||
};
|
||||
|
||||
async fn graphql_playground() -> impl IntoResponse {
|
||||
Html(playground_source(
|
||||
GraphQLPlaygroundConfig::new("/graphql").subscription_endpoint("/ws"),
|
||||
))
|
||||
}
|
||||
async fn graphql_handler(
|
||||
schema: Extension<ScelSchema>,
|
||||
req: Json<Request>,
|
||||
_: User,
|
||||
) -> Json<Response> {
|
||||
schema.execute(req.0).await.into()
|
||||
}
|
||||
|
||||
pub struct Server {
|
||||
app: Router,
|
||||
addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub fn new(app: Arc<App>) -> Server {
|
||||
let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
|
||||
.extension(Tracing)
|
||||
.extension(Logger)
|
||||
.data(app)
|
||||
.finish();
|
||||
|
||||
let cors = vec![
|
||||
"http://localhost:3000"
|
||||
.parse()
|
||||
.expect("Could not parse url"),
|
||||
"https://scel.front.kjuulh.io"
|
||||
.parse()
|
||||
.expect("Could not parse url"),
|
||||
];
|
||||
|
||||
let api_router = Router::new()
|
||||
.route(
|
||||
"/graphql",
|
||||
routing::get(graphql_playground).post(graphql_handler),
|
||||
)
|
||||
.route("/ws", GraphQLSubscription::new(schema.clone()))
|
||||
.route("/auth/gitea", routing::get(gitea))
|
||||
.route("/auth/authorized", routing::get(authorized))
|
||||
// .merge(axum_extra::routing::SpaRouter::new(
|
||||
// "/assets",
|
||||
// "src/web/dist/assets",
|
||||
// ))
|
||||
.fallback(get_service(ServeDir::new("./src/web/dist/")).handle_error(handle_error))
|
||||
.layer(Extension(schema))
|
||||
.layer(Extension(MemoryStore::new()))
|
||||
.layer(Extension(auth::oauth_client()))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors)
|
||||
.allow_headers([axum::http::header::CONTENT_TYPE])
|
||||
.allow_methods([Method::GET, Method::POST, Method::OPTIONS]),
|
||||
)
|
||||
.layer(TraceLayer::new_for_http().make_span_with(DefaultMakeSpan::default()));
|
||||
|
||||
let app = Router::new().nest("/api", api_router);
|
||||
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
|
||||
|
||||
Server { app, addr }
|
||||
}
|
||||
|
||||
pub async fn start(self) -> anyhow::Result<()> {
|
||||
tracing::info!("listening on {}", self.addr);
|
||||
|
||||
match axum::Server::bind(&self.addr)
|
||||
.serve(self.app.into_make_service())
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct User {
|
||||
#[serde(alias = "sub")]
|
||||
id: String,
|
||||
#[serde(alias = "picture")]
|
||||
avatar: Option<String>,
|
||||
#[serde(alias = "email")]
|
||||
email: String,
|
||||
#[serde(alias = "preferred_username")]
|
||||
username: String,
|
||||
}
|
||||
|
||||
struct AuthRedirect;
|
||||
|
||||
impl IntoResponse for AuthRedirect {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
Redirect::temporary("/auth/gitea").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
const COOKIE_NAME: &str = "auth";
|
||||
|
||||
#[async_trait]
|
||||
impl<B> FromRequest<B> for User
|
||||
where
|
||||
B: Send,
|
||||
{
|
||||
type Rejection = AuthRedirect;
|
||||
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let Extension(store) = Extension::<MemoryStore>::from_request(req)
|
||||
.await
|
||||
.expect("MemoryStore extension is missing");
|
||||
|
||||
let cookies = TypedHeader::<headers::Cookie>::from_request(req)
|
||||
.await
|
||||
.map_err(|e| match *e.name() {
|
||||
header::COOKIE => match e.reason() {
|
||||
TypedHeaderRejectionReason::Missing => AuthRedirect,
|
||||
_ => panic!("unexpected error getting Cookie header(s): {}", e),
|
||||
},
|
||||
_ => panic!("unexpected error getting cookies: {}", e),
|
||||
})?;
|
||||
|
||||
let session_cookie = cookies.get(COOKIE_NAME).ok_or(AuthRedirect)?;
|
||||
|
||||
let session = store
|
||||
.load_session(session_cookie.to_string())
|
||||
.await
|
||||
.expect("could not load session")
|
||||
.ok_or(AuthRedirect)?;
|
||||
|
||||
let user = session.get::<User>("user").ok_or(AuthRedirect)?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_error(_err: io::Error) -> impl IntoResponse {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong...")
|
||||
}
|
2
crates/scel_core/.gitignore
vendored
Normal file
2
crates/scel_core/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
/target
|
||||
/Cargo.lock
|
17
crates/scel_core/Cargo.toml
Normal file
17
crates/scel_core/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "scel_core"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.22", features = ["full"] }
|
||||
anyhow = { version = "*" }
|
||||
async-trait = { version = "0.1.58" }
|
||||
futures = "0.3.28"
|
||||
tracing = "0.1"
|
||||
lazy_static = "1.4.0"
|
||||
regex = { version = "1.7.0" }
|
||||
thiserror = "1.0.37"
|
||||
uuid = {version = "1.2.2", features = ["v4", "fast-rng"]}
|
19
crates/scel_core/src/lib.rs
Normal file
19
crates/scel_core/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use services::InMemoryDownloadService;
|
||||
|
||||
pub mod services;
|
||||
mod youtube;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct App {
|
||||
pub download_service: Arc<InMemoryDownloadService>,
|
||||
}
|
||||
|
||||
impl App {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
download_service: Arc::new(InMemoryDownloadService::new()),
|
||||
}
|
||||
}
|
||||
}
|
3
crates/scel_core/src/repo/users_repo.rs
Normal file
3
crates/scel_core/src/repo/users_repo.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub trait UsersRepo {
|
||||
// add code here
|
||||
}
|
128
crates/scel_core/src/services/mod.rs
Normal file
128
crates/scel_core/src/services/mod.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
use std::{collections::HashMap, path::PathBuf, sync::Arc};
|
||||
use tokio::sync::{watch, Mutex};
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::youtube::{Arg, YoutubeDL};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Download {
|
||||
pub id: Option<String>,
|
||||
pub link: String,
|
||||
pub progress: Option<u32>,
|
||||
pub file_name: Option<String>,
|
||||
}
|
||||
|
||||
pub struct InMemoryDownloadService {
|
||||
downloads: Mutex<
|
||||
HashMap<
|
||||
String,
|
||||
(
|
||||
Arc<Mutex<Download>>,
|
||||
Arc<Mutex<tokio::sync::watch::Sender<Download>>>,
|
||||
tokio::sync::watch::Receiver<Download>,
|
||||
),
|
||||
>,
|
||||
>,
|
||||
}
|
||||
|
||||
impl InMemoryDownloadService {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
downloads: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_download(self: Arc<Self>, download: Download) -> anyhow::Result<Download> {
|
||||
let mut downloads = self.downloads.lock().await;
|
||||
|
||||
let (tx, rx) = watch::channel(download.clone());
|
||||
let shared_tx = Arc::new(Mutex::new(tx));
|
||||
|
||||
let mut d = download.to_owned();
|
||||
|
||||
let id = Uuid::new_v4().to_string();
|
||||
d.id = Some(id.clone());
|
||||
|
||||
downloads.insert(id.clone(), (Arc::new(Mutex::new(d.clone())), shared_tx, rx));
|
||||
|
||||
let args = vec![
|
||||
Arg::new("--progress"),
|
||||
Arg::new("--newline"),
|
||||
Arg::new_with_args("--output", "%(title).90s.%(ext)s"),
|
||||
];
|
||||
let ytd = YoutubeDL::new(
|
||||
&PathBuf::from("./data/downloads"),
|
||||
args,
|
||||
download.link.as_str(),
|
||||
)?;
|
||||
|
||||
tokio::spawn({
|
||||
let download_service = self.clone();
|
||||
|
||||
async move {
|
||||
if let Err(e) = ytd
|
||||
.download(
|
||||
|percentage| {
|
||||
let ds = download_service.clone();
|
||||
let id = id.clone();
|
||||
|
||||
async move {
|
||||
let mut download = ds.get_download(id).await.unwrap().unwrap();
|
||||
download.progress = Some(percentage);
|
||||
let _ = ds.update_download(download).await;
|
||||
}
|
||||
},
|
||||
|file_name| {
|
||||
let ds = download_service.clone();
|
||||
let id = id.clone();
|
||||
|
||||
async move {
|
||||
let mut download = ds.get_download(id).await.unwrap().unwrap();
|
||||
download.file_name = Some(file_name);
|
||||
let _ = ds.update_download(download).await;
|
||||
}
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("Download failed: {}", e);
|
||||
} else {
|
||||
let download = download_service.get_download(id).await.unwrap().unwrap();
|
||||
let _ = download_service.update_download(download).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(d)
|
||||
}
|
||||
|
||||
pub async fn update_download(self: Arc<Self>, download: Download) -> anyhow::Result<()> {
|
||||
let mut downloads = self.downloads.lock().await;
|
||||
if let Some(d) = downloads.get_mut(&download.clone().id.unwrap()) {
|
||||
let mut d_mut = d.0.lock().await;
|
||||
*d_mut = download.clone();
|
||||
let _ = d.1.lock().await.send(download);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_download(&self, id: String) -> anyhow::Result<Option<Download>> {
|
||||
let downloads = self.downloads.lock().await;
|
||||
|
||||
if let Some(d) = downloads.get(&id) {
|
||||
let download = d.0.lock().await;
|
||||
|
||||
Ok(Some(download.clone()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn subscribe_download(&self, id: String) -> tokio::sync::watch::Receiver<Download> {
|
||||
let downloads = self.downloads.lock().await;
|
||||
let download = downloads.get(&id).unwrap();
|
||||
download.2.clone()
|
||||
}
|
||||
}
|
256
crates/scel_core/src/youtube/mod.rs
Normal file
256
crates/scel_core/src/youtube/mod.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::fs::{canonicalize, create_dir_all};
|
||||
use std::future::Future;
|
||||
use std::num::ParseIntError;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Output, Stdio};
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use regex::Regex;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum YoutubeDLError {
|
||||
#[error("failed to execute youtube-dl")]
|
||||
IOError(#[from] std::io::Error),
|
||||
#[error("failed to convert path")]
|
||||
UTF8Error(#[from] std::string::FromUtf8Error),
|
||||
#[error("youtube-dl exited with: {0}")]
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, YoutubeDLError>;
|
||||
|
||||
const YOUTUBE_DL_COMMAND: &str = "yt-dlp";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Arg {
|
||||
arg: String,
|
||||
input: Option<String>,
|
||||
}
|
||||
|
||||
impl Arg {
|
||||
pub fn new(argument: &str) -> Self {
|
||||
Self {
|
||||
arg: argument.to_string(),
|
||||
input: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_args(argument: &str, input: &str) -> Self {
|
||||
Self {
|
||||
arg: argument.to_string(),
|
||||
input: Option::from(input.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Arg {
|
||||
fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.input {
|
||||
Some(input) => write!(fmt, "{} {}", self.arg, input),
|
||||
None => write!(fmt, "{}", self.arg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct YoutubeDL {
|
||||
path: PathBuf,
|
||||
links: Vec<String>,
|
||||
args: Vec<Arg>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct YoutubeDLResult {
|
||||
path: PathBuf,
|
||||
output: String,
|
||||
}
|
||||
|
||||
impl YoutubeDLResult {
|
||||
fn new(path: &PathBuf) -> Self {
|
||||
Self {
|
||||
path: path.clone(),
|
||||
output: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn output_dir(&self) -> &PathBuf {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl YoutubeDL {
|
||||
pub fn new_multiple_links(
|
||||
dl_path: &PathBuf,
|
||||
args: Vec<Arg>,
|
||||
links: Vec<String>,
|
||||
) -> Result<YoutubeDL> {
|
||||
let path = Path::new(dl_path);
|
||||
|
||||
if !path.exists() {
|
||||
create_dir_all(&path)?;
|
||||
}
|
||||
|
||||
if !path.is_dir() {
|
||||
return Err(YoutubeDLError::IOError(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"path is not a directory",
|
||||
)));
|
||||
}
|
||||
|
||||
let path = canonicalize(dl_path)?;
|
||||
Ok(YoutubeDL { path, links, args })
|
||||
}
|
||||
|
||||
pub fn new(dl_path: &PathBuf, args: Vec<Arg>, link: &str) -> Result<YoutubeDL> {
|
||||
YoutubeDL::new_multiple_links(dl_path, args, vec![link.to_string()])
|
||||
}
|
||||
|
||||
pub async fn download<F, FutAvailable, FAvailable, Fut>(
|
||||
&self,
|
||||
progress_update_fn: F,
|
||||
file_name_available: FAvailable,
|
||||
) -> Result<YoutubeDLResult>
|
||||
where
|
||||
F: Fn(u32) -> Fut,
|
||||
FAvailable: Fn(String) -> FutAvailable,
|
||||
Fut: Future<Output = ()>,
|
||||
FutAvailable: Future<Output = ()>,
|
||||
{
|
||||
let output = self
|
||||
.spawn_youtube_dl(progress_update_fn, file_name_available)
|
||||
.await?;
|
||||
let mut result = YoutubeDLResult::new(&self.path);
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(YoutubeDLError::Failure(String::from_utf8(output.stderr)?));
|
||||
}
|
||||
result.output = String::from_utf8(output.stdout)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn spawn_youtube_dl<F, FutAvailable, FAvailable, Fut>(
|
||||
&self,
|
||||
progress_update_fn: F,
|
||||
file_name_available: FAvailable,
|
||||
) -> Result<Output>
|
||||
where
|
||||
F: Fn(u32) -> Fut,
|
||||
FAvailable: Fn(String) -> FutAvailable,
|
||||
Fut: Future<Output = ()>,
|
||||
FutAvailable: Future<Output = ()>,
|
||||
{
|
||||
let mut cmd = Command::new(YOUTUBE_DL_COMMAND);
|
||||
cmd.current_dir(&self.path)
|
||||
.env("LC_ALL", "en_US.UTF-8")
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped());
|
||||
|
||||
for arg in self.args.iter() {
|
||||
match &arg.input {
|
||||
Some(input) => cmd.arg(&arg.arg).arg(input),
|
||||
None => cmd.arg(&arg.arg),
|
||||
};
|
||||
}
|
||||
|
||||
for link in self.links.iter() {
|
||||
cmd.arg(&link);
|
||||
}
|
||||
|
||||
let mut pr = cmd.spawn()?;
|
||||
|
||||
{
|
||||
let stdout = pr.stdout.as_mut().unwrap();
|
||||
let stdout_reader = BufReader::new(stdout);
|
||||
let mut stdout_lines = stdout_reader.lines();
|
||||
|
||||
let mut have_gotten_file_name = false;
|
||||
while let Ok(Some(line)) = stdout_lines.next_line().await {
|
||||
println!("{}", line.clone());
|
||||
|
||||
if !have_gotten_file_name {
|
||||
if let Some(file_name) = parse_file_name(line.clone()) {
|
||||
file_name_available(file_name).await;
|
||||
have_gotten_file_name = true
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(Ok(percentage)) = parse_line(line) {
|
||||
progress_update_fn(percentage).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(pr.wait_with_output().await?)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_line(line: String) -> Option<core::result::Result<u32, ParseIntError>> {
|
||||
lazy_static! {
|
||||
static ref RE: Regex = Regex::new(r"\[download\]\s+(\d+)").unwrap();
|
||||
}
|
||||
|
||||
let capture: regex::Captures = RE.captures(line.as_str())?;
|
||||
if capture.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let str = &capture[1];
|
||||
Some(str.to_string().parse::<u32>())
|
||||
}
|
||||
|
||||
fn parse_file_name(line: String) -> Option<String> {
|
||||
lazy_static! {
|
||||
static ref RE: Regex = Regex::new(r"^\[download\] Destination: (.+)$").unwrap();
|
||||
}
|
||||
|
||||
let capture: regex::Captures = RE.captures(line.as_str())?;
|
||||
if capture.len() != 2 {
|
||||
return None;
|
||||
}
|
||||
let str = &capture[1];
|
||||
Some(str.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::youtube::{parse_file_name, parse_line};
|
||||
|
||||
#[test]
|
||||
fn test_parse_line() {
|
||||
let percentage = parse_line(
|
||||
"[download] 95.4% of ~215.85MiB at 9.61MiB/s ETA 00:01 (frag 144/151)".into(),
|
||||
);
|
||||
|
||||
assert_eq!(percentage, Some(Ok(95)))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_get_nothing() {
|
||||
let nothing = parse_line("[download] Got server HTTP error: The read operation timed out. Retrying (attempt 1 of 10) ...".into());
|
||||
|
||||
assert_eq!(nothing, None)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_name() {
|
||||
let file_name = parse_file_name(
|
||||
"[download] Destination: 10 Design Patterns Explained in 10 Minutes.mp4".into(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
file_name,
|
||||
Some("10 Design Patterns Explained in 10 Minutes.mp4".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_name_get_nothing() {
|
||||
let nothing = parse_file_name("[download] No fit: something".into());
|
||||
|
||||
assert_eq!(nothing, None)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user