From 7c1cc34ff03fef10a9bb2f3f86af5318cdb6a07a Mon Sep 17 00:00:00 2001 From: Guus van Meerveld Date: Tue, 23 Apr 2024 21:17:50 +0200 Subject: [PATCH] implemented storing of credentials and tokens in keyring --- Cargo.lock | 138 ++++++++++++++++++++ sshn-cli/Cargo.toml | 3 + sshn-cli/src/{commands/login.rs => auth.rs} | 70 ++++++++-- sshn-cli/src/commands/mod.rs | 52 +++++++- sshn-cli/src/error.rs | 12 ++ sshn-cli/src/main.rs | 63 +++++++-- sshn-cli/src/publication.rs | 61 +++++++++ sshn-cli/src/secrets.rs | 82 ++++++++++++ sshn-lib/Cargo.toml | 2 +- sshn-lib/src/client.rs | 111 +++++++--------- sshn-lib/src/constants.rs | 2 +- sshn-lib/src/error.rs | 2 + sshn-lib/src/lib.rs | 8 +- sshn-lib/src/tokens.rs | 56 +++++++- 14 files changed, 561 insertions(+), 101 deletions(-) rename sshn-cli/src/{commands/login.rs => auth.rs} (61%) create mode 100644 sshn-cli/src/publication.rs create mode 100644 sshn-cli/src/secrets.rs diff --git a/Cargo.lock b/Cargo.lock index 72c8469..1b9acd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -409,6 +409,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-targets 0.52.5", ] @@ -542,6 +543,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "deranged" version = "0.3.11" @@ -573,6 +595,27 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "dotenv" version = "0.15.0" @@ -585,6 +628,12 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.34" @@ -1256,6 +1305,17 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +[[package]] +name = "is-terminal" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1297,6 +1357,16 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.5.0", + "libc", +] + [[package]] name = "linux-keyutils" version = "0.2.4" @@ -1708,6 +1778,20 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "prettytable" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46480520d1b77c9a3482d39939fcf96831537a250ec62d4fd8fbdf8e0302e781" +dependencies = [ + "csv", + "encode_unicode", + "is-terminal", + "lazy_static", + "term", + "unicode-width", +] + [[package]] name = "proc-macro-crate" version = "1.3.1" @@ -1775,6 +1859,17 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_users" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.10.4" @@ -1916,6 +2011,12 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" +[[package]] +name = "rustversion" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" + [[package]] name = "ryu" version = "1.0.17" @@ -2108,11 +2209,14 @@ dependencies = [ "fantoccini", "keyring", "log", + "prettytable", "rpassword", "serde", + "serde_json", "sshn-lib", "thiserror", "tokio", + "whoami", ] [[package]] @@ -2214,6 +2318,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "term" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c59df8ac95d96ff9bede18eb7300b0fda5e5d8d90960e76f8e14ae765eedbf1f" +dependencies = [ + "dirs-next", + "rustversion", + "winapi", +] + [[package]] name = "thiserror" version = "1.0.59" @@ -2461,6 +2576,12 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "unreachable" version = "1.0.0" @@ -2526,6 +2647,12 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -2621,6 +2748,17 @@ dependencies = [ "url", ] +[[package]] +name = "whoami" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44ab49fad634e88f55bf8f9bb3abd2f27d7204172a112c7c9987e01c1c94ea9" +dependencies = [ + "redox_syscall", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/sshn-cli/Cargo.toml b/sshn-cli/Cargo.toml index ac951b9..fcb2ecb 100644 --- a/sshn-cli/Cargo.toml +++ b/sshn-cli/Cargo.toml @@ -19,3 +19,6 @@ thiserror = "1.0.59" rpassword = "7.3.1" serde = { version = "1.0.198", features = ["derive"] } keyring = "2.3.2" +whoami = "1.5.1" +serde_json = "1.0.116" +prettytable = "0.10.0" diff --git a/sshn-cli/src/commands/login.rs b/sshn-cli/src/auth.rs similarity index 61% rename from sshn-cli/src/commands/login.rs rename to sshn-cli/src/auth.rs index 00e385a..7c3d7de 100644 --- a/sshn-cli/src/commands/login.rs +++ b/sshn-cli/src/auth.rs @@ -1,20 +1,60 @@ use std::time::Duration; use fantoccini::{ClientBuilder, Locator}; -use sshn_lib::{generate_auth_url, get_code_challenge, LoginType}; +use sshn_lib::{generate_auth_url, get_code_challenge, AuthenticatedClient, LoginType}; use crate::{ error::{Error, Result}, + secrets::{self, Credentials}, WebDriver, }; const LOGIN_FORM_ID: &str = "kc-form-login"; +#[derive(Debug)] +pub struct AuthOptions { + webdriver: WebDriver, + webdriver_port: u16, + + login_base_url: Option, +} + +impl AuthOptions { + pub fn webdriver(self, webdriver: WebDriver) -> Self { + Self { webdriver, ..self } + } + + pub fn webdriver_port(self, webdriver_port: u16) -> Self { + Self { + webdriver_port, + ..self + } + } + + pub fn login_base_url>(self, login_base_url: L) -> Self { + Self { + login_base_url: Some(login_base_url.into()), + ..self + } + } +} + +impl Default for AuthOptions { + fn default() -> Self { + Self { + login_base_url: None, + webdriver: WebDriver::Chromium, + webdriver_port: 4444, + } + } +} + /// Starts the given webdriver on the given port, then waits until said driver has started up. async fn start_web_driver(webdriver: WebDriver, port: u16) -> Result { let process = match webdriver { WebDriver::Chromium => tokio::process::Command::new("chromedriver") .arg(format!("--port={}", port)) + .arg("--headless") .spawn()?, WebDriver::Gecko => tokio::process::Command::new("geckodriver") .arg("--port") @@ -49,37 +89,34 @@ async fn start_web_driver(webdriver: WebDriver, port: u16) -> Result, P: AsRef, L: Into>( +pub async fn password_login, P: AsRef>( username: U, password: P, - webdriver: WebDriver, - login_url: Option, -) -> Result<()> { + options: AuthOptions, +) -> Result { let client = sshn_lib::Client::new(None); let (code_challenge, code_verifier) = get_code_challenge(); - let login_url: String = match login_url { + let login_base_url: String = match options.login_base_url.as_ref() { Some(url) => url.into(), None => { let endpoints = client.get_endpoints().await?; - let base_login_url = endpoints + endpoints .identity_config .ok_or(Error::MissingLoginUrl)? .authorization_endpoint - .ok_or(Error::MissingLoginUrl)?; - - generate_auth_url(base_login_url, code_challenge)? + .ok_or(Error::MissingLoginUrl)? } }; - let port = 4444; + let login_url = generate_auth_url(login_base_url, code_challenge)?; - let mut driver = start_web_driver(webdriver, port).await?; + let mut driver = start_web_driver(options.webdriver, options.webdriver_port).await?; let browser = ClientBuilder::native() - .connect(&format!("http://localhost:{}", port)) + .connect(&format!("http://localhost:{}", options.webdriver_port)) .await?; log::info!("Logging into SSHN at {}", login_url); @@ -114,5 +151,10 @@ pub async fn login, P: AsRef, L: Into>( }) .await?; - Ok(()) + let credentials = Credentials::new(username.as_ref(), password.as_ref()); + + secrets::set("credentials", &credentials)?; + secrets::set("tokens", auth_client.tokens())?; + + Ok(auth_client) } diff --git a/sshn-cli/src/commands/mod.rs b/sshn-cli/src/commands/mod.rs index 320cbbb..6661b13 100644 --- a/sshn-cli/src/commands/mod.rs +++ b/sshn-cli/src/commands/mod.rs @@ -1 +1,51 @@ -pub mod login; +use crate::{ + auth::{self, AuthOptions}, + error::Result, + publication::{self, Publication}, + secrets, +}; + +pub async fn login, P: AsRef>( + username: U, + password: P, + options: AuthOptions, +) -> Result<()> { + auth::password_login(username.as_ref(), password.as_ref(), options).await?; + + Ok(()) +} + +pub async fn list(limit: usize) -> Result<()> { + let data = publication::list_publications(limit).await?; + + let mut table = prettytable::Table::new(); + + table.add_row(prettytable::Row::new( + Publication::row_labels() + .iter() + .map(|label| prettytable::Cell::new(label)) + .collect(), + )); + + for publication in data { + table.add_row(prettytable::Row::new( + publication + .as_row() + .iter() + .map(|label| prettytable::Cell::new(label)) + .collect(), + )); + } + + table.printstd(); + + Ok(()) +} + +pub async fn reply>(id: I) -> Result<()> { + let mut client = secrets::get_client().await?; + + client.reply_to_publication(id.as_ref()).await?; + + Ok(()) +} diff --git a/sshn-cli/src/error.rs b/sshn-cli/src/error.rs index 0a1a56e..0593684 100644 --- a/sshn-cli/src/error.rs +++ b/sshn-cli/src/error.rs @@ -11,6 +11,12 @@ pub enum Error { #[error("SSHN Api did not return valid authorization code")] MissingAuthCode, + #[error("SSHN Api did not return valid publications")] + MissingPublications, + + #[error("Missing username and password credentials")] + MissingCredentials, + #[error("Failed to start web driver")] WebDriverStart, @@ -20,6 +26,12 @@ pub enum Error { #[error("Failed communicating with browser: {0}")] HeadlessBrowser(#[from] fantoccini::error::CmdError), + #[error("Keyring error: {0}")] + Keyring(#[from] keyring::Error), + + #[error("Failed to serialize/deserialize JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("IO error: {0}")] Io(#[from] std::io::Error), } diff --git a/sshn-cli/src/main.rs b/sshn-cli/src/main.rs index dce8f08..d333ed1 100644 --- a/sshn-cli/src/main.rs +++ b/sshn-cli/src/main.rs @@ -2,12 +2,15 @@ use clap::{Parser, Subcommand}; use rpassword::prompt_password; use serde::Serialize; -use crate::commands::login::login; - +mod auth; mod commands; mod error; +mod publication; +mod secrets; + +use auth::AuthOptions; -/// Simple program to greet a person +/// SSHN command line interface. #[derive(Parser, Debug)] #[command(version, about, long_about = None)] #[command(propagate_version = true)] @@ -18,22 +21,33 @@ struct Args { #[derive(Subcommand, Debug)] pub enum Commands { + /// Login to the SSHN API. Login { - /// Username of the SSHN account + /// Username of the SSHN account. #[arg(short, long)] username: String, - /// Password of the SSHN account + /// Password of the SSHN account. #[arg(short, long)] password: Option, - /// The login url + /// The login portal base url. #[arg(short, long)] login_url: Option, + /// The web driver to use to connect to the browser. #[arg(short, long, default_value_t, value_enum)] webdriver: WebDriver, }, + + /// List the currently open publications. + List { + #[arg(short, long)] + limit: Option, + }, + + /// Reply to a publication with a given id. + Reply { id: String }, } #[derive(clap::ValueEnum, Serialize, Debug, Clone, Default)] @@ -46,7 +60,14 @@ pub enum WebDriver { #[tokio::main] async fn main() { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + { + let mut builder = env_logger::Builder::from_default_env(); + + builder.filter_module("sshn_cli", log::LevelFilter::Info); + builder.filter_module("sshn_lib", log::LevelFilter::Info); + + builder.init(); + } let args = Args::parse(); @@ -69,9 +90,13 @@ async fn main() { log::info!("Logging in as user '{}'", username); - let login_result = login(&username, &password, webdriver, login_url).await; - - match login_result { + match commands::login( + &username, + &password, + AuthOptions::default().webdriver(webdriver), + ) + .await + { Ok(_) => { log::info!("Succesfully logged in as user '{}'", username) } @@ -80,5 +105,23 @@ async fn main() { } } } + + Commands::List { limit } => { + match commands::list(limit.unwrap_or(5)).await { + Ok(_) => {} + Err(error) => { + log::error!("Error listing publications: {}", error); + } + }; + } + + Commands::Reply { id } => { + match commands::reply(id).await { + Ok(_) => {} + Err(error) => { + log::error!("Error replying to publication: {}", error); + } + }; + } } } diff --git a/sshn-cli/src/publication.rs b/sshn-cli/src/publication.rs new file mode 100644 index 0000000..2552964 --- /dev/null +++ b/sshn-cli/src/publication.rs @@ -0,0 +1,61 @@ +use crate::error::{Error, Result}; + +pub async fn list_publications(limit: usize) -> Result> { + let client = sshn_lib::Client::new(None); + + let publications = client.get_publications_list(limit as i64).await?; + + Ok(publications + .housing_publications + .ok_or(Error::MissingPublications)? + .nodes + .ok_or(Error::MissingPublications)? + .edges + .ok_or(Error::MissingPublications)? + .into_iter() + .filter_map(|publication| { + let publication = publication?.node?; + + // let city = publication.unit?.location?.city?.name.as_ref()?.to_string(); + let rent = publication.unit?.gross_rent.as_ref()?.exact; + + Some(Publication { + id: publication.id, + name: String::new(), + city: String::new(), + nr_of_applicants: publication.total_number_of_applications, + rent, + }) + }) + .collect()) +} + +pub struct Publication { + id: String, + name: String, + city: String, + nr_of_applicants: i64, + rent: f64, +} + +impl Publication { + pub fn as_row(self) -> Vec { + vec![ + self.name, + self.city, + self.nr_of_applicants.to_string(), + self.rent.to_string(), + self.id, + ] + } + + pub fn row_labels() -> Vec { + vec![ + String::from("Name"), + String::from("City"), + String::from("Number of applicants"), + String::from("Gross rent"), + String::from("ID"), + ] + } +} diff --git a/sshn-cli/src/secrets.rs b/sshn-cli/src/secrets.rs new file mode 100644 index 0000000..be9f075 --- /dev/null +++ b/sshn-cli/src/secrets.rs @@ -0,0 +1,82 @@ +use keyring::Entry; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use sshn_lib::{AuthenticatedClient, Tokens}; + +pub use crate::error::Result; +use crate::{auth, error::Error}; + +const SERVICE_NAME: &str = "SSHN-cli"; + +#[derive(Serialize, Deserialize, Debug)] +pub struct Credentials { + username: String, + password: String, +} + +impl Credentials { + pub fn new, P: Into>(username: U, password: P) -> Self { + Self { + username: username.into(), + password: password.into(), + } + } +} + +pub fn set, T: Serialize>(identifier: I, data: &T) -> Result<()> { + let user = whoami::username(); + + let entry_name = format!("{}-{}", identifier.as_ref(), SERVICE_NAME); + + let entry = Entry::new(&entry_name, &user)?; + + let data = serde_json::to_string(data)?; + + entry.set_password(&data)?; + + Ok(()) +} + +pub fn get, T: DeserializeOwned>(identifier: I) -> Result { + let user = whoami::username(); + + let entry_name = format!("{}-{}", identifier.as_ref(), SERVICE_NAME); + + let entry = Entry::new(&entry_name, &user)?; + + let data = entry.get_password()?; + + let data = serde_json::from_str(&data)?; + + Ok(data) +} + +pub async fn get_client() -> Result { + let client = sshn_lib::Client::new(None); + + if let Ok(tokens) = get::<_, Tokens>("tokens") { + if !tokens.access_token().has_expired() { + return Ok(AuthenticatedClient::new(None, tokens)); + } else { + if !tokens.refresh_token().has_expired() { + return Ok(client + .login(sshn_lib::LoginType::RefreshToken { + token: tokens.refresh_token().content().to_string(), + }) + .await?); + } + } + } + + log::info!("Tokens expired, logging in using credentials"); + + if let Ok(credentials) = get::<_, Credentials>("credentials") { + return auth::password_login( + credentials.username, + credentials.password, + Default::default(), + ) + .await; + } + + Err(Error::MissingCredentials) +} diff --git a/sshn-lib/Cargo.toml b/sshn-lib/Cargo.toml index 0101efc..cb53542 100644 --- a/sshn-lib/Cargo.toml +++ b/sshn-lib/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] base64 = "0.22.0" -chrono = "0.4.38" +chrono = { version = "0.4.38", features = ["serde"] } digest = "0.10.7" graphql_client = "0.14.0" log = "0.4.21" diff --git a/sshn-lib/src/client.rs b/sshn-lib/src/client.rs index 4200585..9c2e293 100644 --- a/sshn-lib/src/client.rs +++ b/sshn-lib/src/client.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use chrono::{Duration, Utc}; use graphql_client::GraphQLQuery; use serde::{de::DeserializeOwned, Serialize}; @@ -12,7 +11,7 @@ use crate::{ post_application::{self, HousingApplyState}, GetIdentityConfig, GetPublicationsList, GraphqlResponse, PostApplication, }, - tokens::{RefreshTokenResponse, Token, TokenType}, + tokens::{LoginResponse, Tokens}, }; pub struct Client { @@ -22,6 +21,7 @@ pub struct Client { pub enum LoginType { AuthCode { code: String, verifier: String }, + RefreshToken { token: String }, Password { username: String, password: String }, } @@ -33,7 +33,7 @@ impl Client { } } - pub async fn login(self, login_type: LoginType) -> Result { + pub async fn auth(&self, login_type: LoginType) -> Result { let mut params = HashMap::new(); params.insert("client_id", CLIENT_ID); @@ -42,9 +42,14 @@ impl Client { LoginType::AuthCode { code, verifier } => { params.insert("grant_type", "authorization_code"); params.insert("redirect_uri", REDIRECT_URI); + params.insert("code_verifier", &verifier); params.insert("code", code); } + LoginType::RefreshToken { token } => { + params.insert("grant_type", "refresh_token"); + params.insert("refresh_token", token); + } LoginType::Password { username, password } => { params.insert("grant_type", "password"); @@ -72,26 +77,19 @@ impl Client { return Err(Error::HttpRequest(err)); }; - let tokens = response.json::().await?; + let response_data = response.json::().await?; - let access_token = Token::new( - tokens.access_token, - Utc::now() + Duration::seconds(tokens.expires_in), - TokenType::Access, - ); + Ok(response_data.into()) + } - let refresh_token = Token::new( - tokens.refresh_token, - Utc::now() + Duration::seconds(tokens.refresh_expires_in), - TokenType::Refresh, - ); + pub async fn login(self, login_type: LoginType) -> Result { + let tokens = self.auth(login_type).await?; let authenticated_client = AuthenticatedClient { - graphql_url: self.graphql_url, - http_client: self.http_client, - token_url: TOKEN_URL.to_string(), - access_token, - refresh_token, + graphql_url: self.graphql_url.clone(), + http_client: reqwest::Client::new(), + client: self, + tokens, }; Ok(authenticated_client) @@ -149,57 +147,36 @@ impl Client { pub struct AuthenticatedClient { graphql_url: String, - token_url: String, http_client: reqwest::Client, - access_token: Token, - refresh_token: Token, + client: Client, + tokens: Tokens, +} + +impl Into for AuthenticatedClient { + fn into(self) -> Tokens { + self.tokens + } } impl AuthenticatedClient { - async fn refresh_tokens(&mut self) -> Result<()> { - if self.refresh_token.expires() < Utc::now() { - return Err(Error::TokenExpired); + pub fn new(graphql_url: Option, tokens: Tokens) -> Self { + Self { + graphql_url: graphql_url.clone().unwrap_or(GRAPHQL_URL.to_string()), + http_client: reqwest::Client::new(), + client: Client::new(graphql_url), + tokens, } - - let mut params = HashMap::new(); - - params.insert("client_id", CLIENT_ID); - params.insert("grant_type", "refresh_token"); - params.insert("refresh_token", self.refresh_token.as_ref()); - - let body = serde_urlencoded::to_string(¶ms)?; - - let response = self - .http_client - .post(&self.token_url) - .body(body) - .header( - reqwest::header::CONTENT_TYPE, - "application/x-www-form-urlencoded", - ) - .send() - .await?; - - let tokens = response.json::().await?; - - self.access_token = Token::new( - tokens.access_token, - Utc::now() + Duration::seconds(tokens.expires_in), - TokenType::Access, - ); - - self.refresh_token = Token::new( - tokens.refresh_token, - Utc::now() + Duration::seconds(tokens.refresh_expires_in), - TokenType::Refresh, - ); - - Ok(()) } async fn check_expiration(&mut self) -> Result<()> { - if self.access_token.expires() < Utc::now() { - self.refresh_tokens().await?; + if self.tokens.access_token().has_expired() { + if !self.tokens.refresh_token().has_expired() { + let token = self.tokens.refresh_token().content().to_string(); + + self.tokens = self.client.auth(LoginType::RefreshToken { token }).await?; + } else { + return Err(Error::TokenExpired); + } } Ok(()) @@ -211,7 +188,7 @@ impl AuthenticatedClient { let response = self .http_client .post(&self.graphql_url) - .bearer_auth(self.access_token.as_ref()) + .bearer_auth(self.tokens.access_token().as_ref()) .json(query) .send() .await?; @@ -223,8 +200,12 @@ impl AuthenticatedClient { Ok(response_body.data) } - pub fn tokens(&self) -> (&Token, &Token) { - (&self.access_token, &self.refresh_token) + pub fn tokens(&self) -> &Tokens { + &self.tokens + } + + pub fn client(&self) -> &Client { + &self.client } /// Reply to a publication, given that publications id. diff --git a/sshn-lib/src/constants.rs b/sshn-lib/src/constants.rs index 20f649f..c406af0 100644 --- a/sshn-lib/src/constants.rs +++ b/sshn-lib/src/constants.rs @@ -5,6 +5,6 @@ pub const TOKEN_URL: &str = pub const REDIRECT_URI: &str = "https://mijn.sshn.nl/authentication/callback"; -pub const LOCALE: &str = "nl-NL"; +pub const LOCALE: &str = "en-US"; pub const CLIENT_ID: &str = "portal-legacy"; diff --git a/sshn-lib/src/error.rs b/sshn-lib/src/error.rs index f6962b1..be7dae3 100644 --- a/sshn-lib/src/error.rs +++ b/sshn-lib/src/error.rs @@ -10,6 +10,8 @@ pub enum Error { HttpRequest(#[from] reqwest::Error), #[error("The refresh token expired")] TokenExpired, + #[error("Missing refresh token to get new tokens")] + MissingRefreshToken, #[error("The authentication endpoint is missing")] NoAuthUrl, #[error("Failed to parse url: {0}")] diff --git a/sshn-lib/src/lib.rs b/sshn-lib/src/lib.rs index dfe0a11..987d476 100644 --- a/sshn-lib/src/lib.rs +++ b/sshn-lib/src/lib.rs @@ -5,9 +5,11 @@ mod queries; mod tokens; mod utils; -pub use crate::client::{AuthenticatedClient, Client, LoginType}; - -pub use utils::{generate_auth_url, get_code_challenge}; +pub use { + client::{AuthenticatedClient, Client, LoginType}, + tokens::{Token, TokenType, Tokens}, + utils::{generate_auth_url, get_code_challenge}, +}; #[cfg(test)] mod tests { diff --git a/sshn-lib/src/tokens.rs b/sshn-lib/src/tokens.rs index f885e3b..fbe5521 100644 --- a/sshn-lib/src/tokens.rs +++ b/sshn-lib/src/tokens.rs @@ -1,8 +1,8 @@ -use chrono::{DateTime, Utc}; -use serde::Deserialize; +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; #[derive(Deserialize, Debug)] -pub struct RefreshTokenResponse { +pub struct LoginResponse { pub access_token: String, pub expires_in: i64, pub refresh_expires_in: i64, @@ -12,14 +12,31 @@ pub struct RefreshTokenResponse { // session_state: String, } -#[derive(Debug)] +impl Into for LoginResponse { + fn into(self) -> Tokens { + Tokens::new( + Token::new( + self.refresh_token, + Utc::now() + Duration::seconds(self.refresh_expires_in), + TokenType::Refresh, + ), + Token::new( + self.access_token, + Utc::now() + Duration::seconds(self.expires_in), + TokenType::Access, + ), + ) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Token { r#type: TokenType, content: String, expires: DateTime, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Serialize, Deserialize, Clone)] pub enum TokenType { #[default] Access, @@ -31,7 +48,7 @@ impl Default for Token { Token { content: String::new(), expires: Utc::now(), - ..Default::default() + r#type: TokenType::Access, } } } @@ -58,4 +75,31 @@ impl Token { pub fn expires(&self) -> DateTime { self.expires } + + pub fn has_expired(&self) -> bool { + self.expires <= Utc::now() + } +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +pub struct Tokens { + refresh_token: Token, + access_token: Token, +} + +impl Tokens { + pub fn new(refresh_token: Token, access_token: Token) -> Self { + Self { + refresh_token, + access_token, + } + } + + pub fn refresh_token(&self) -> &Token { + &self.refresh_token + } + + pub fn access_token(&self) -> &Token { + &self.access_token + } }