feat(sdk): started work on oauth support
continuous-integration/drone/push Build is failing
Details
continuous-integration/drone/push Build is failing
Details
parent
8cf965ef81
commit
cc322a8182
@ -0,0 +1,54 @@
|
||||
use hyper::{body, client::HttpConnector, Body, Client, Request};
|
||||
use hyper_tls::HttpsConnector;
|
||||
|
||||
use crate::types::{Error, ErrorKind, Result};
|
||||
|
||||
pub struct HttpClient {
|
||||
client: Client<HttpConnector>,
|
||||
client_tls: Client<HttpsConnector<HttpConnector>>,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub fn new() -> Self {
|
||||
let http_connector = HttpConnector::new();
|
||||
let client = Client::builder().build(http_connector);
|
||||
|
||||
let https_connector = HttpsConnector::new();
|
||||
let client_tls = Client::builder().build(https_connector);
|
||||
|
||||
Self { client, client_tls }
|
||||
}
|
||||
pub async fn request(&self, req: Request<Body>) -> Result<Vec<u8>> {
|
||||
let secure = match req.uri().scheme() {
|
||||
Some(scheme) => scheme.as_str() == "https",
|
||||
None => {
|
||||
return Err(Error::new(
|
||||
ErrorKind::CreateHttpRequest,
|
||||
"Missing request scheme",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let response_result = if secure {
|
||||
self.client_tls.request(req).await
|
||||
} else {
|
||||
self.client.request(req).await
|
||||
};
|
||||
|
||||
let response = response_result.map_err(|err| {
|
||||
Error::new(
|
||||
ErrorKind::CreateHttpRequest,
|
||||
format!("Failed to create http request: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let body = body::to_bytes(response.into_body()).await.map_err(|err| {
|
||||
Error::new(
|
||||
ErrorKind::CreateHttpRequest,
|
||||
format!("Failed to read response body in http request: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Vec::from(body))
|
||||
}
|
||||
}
|
@ -0,0 +1,92 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
http::HttpClient,
|
||||
types::{Error, ErrorKind, Result},
|
||||
};
|
||||
|
||||
use hyper::{header::CONTENT_TYPE, Body, Request};
|
||||
use rocket::serde::{json::from_slice as parse_json_from_slice, Deserialize};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct AccessTokenResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: u16,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
impl AccessTokenResponse {
|
||||
pub fn access_token(&self) -> &str {
|
||||
&self.access_token
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_access_token<S: AsRef<str>>(
|
||||
http_client: &HttpClient,
|
||||
token_url: S,
|
||||
code: S,
|
||||
redirect_uri: S,
|
||||
client_id: S,
|
||||
client_secret: &Option<String>,
|
||||
) -> Result<AccessTokenResponse> {
|
||||
let mut form_data = HashMap::new();
|
||||
|
||||
form_data.insert("grant_type", "authorization_code");
|
||||
form_data.insert("code", code.as_ref());
|
||||
form_data.insert("redirect_uri", redirect_uri.as_ref());
|
||||
form_data.insert("client_id", client_id.as_ref());
|
||||
|
||||
match client_secret.as_ref() {
|
||||
Some(client_secret) => {
|
||||
form_data.insert("client_secret", client_secret.as_ref());
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
let encoded_form_data = serde_urlencoded::to_string(form_data).map_err(|err| {
|
||||
Error::new(
|
||||
ErrorKind::CreateHttpRequest,
|
||||
format!("Failed to create http request for oauth token: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
let url: hyper::Uri = token_url.as_ref().parse().map_err(|_| {
|
||||
Error::new(
|
||||
ErrorKind::BadConfig,
|
||||
"Failed to parse token url from config",
|
||||
)
|
||||
})?;
|
||||
|
||||
let request = Request::builder()
|
||||
.uri(url)
|
||||
.method("POST")
|
||||
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
|
||||
.body(Body::from(encoded_form_data))
|
||||
.map_err(|_| {
|
||||
Error::new(
|
||||
ErrorKind::CreateHttpRequest,
|
||||
"Failed to create http request to request oauth token",
|
||||
)
|
||||
})?;
|
||||
|
||||
let token_response = http_client.request(request).await.map_err(|err| {
|
||||
Error::new(
|
||||
ErrorKind::InternalError,
|
||||
format!("Failed to request oauth access token: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
// println!("{}", String::from_utf8(token_response.clone()).unwrap());
|
||||
|
||||
let access_token_response: AccessTokenResponse = parse_json_from_slice(&token_response)
|
||||
.map_err(|err| {
|
||||
Error::new(
|
||||
ErrorKind::Parse,
|
||||
format!("Invalid response when fetching oauth access token: {}", err),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(access_token_response)
|
||||
}
|
@ -1,7 +1,9 @@
|
||||
mod boxes;
|
||||
mod login;
|
||||
mod logout;
|
||||
mod oauth2;
|
||||
|
||||
pub use boxes::*;
|
||||
pub use login::login as mail_login_handler;
|
||||
pub use logout::logout as mail_logout_handler;
|
||||
pub use oauth2::*;
|
||||
|
@ -0,0 +1,5 @@
|
||||
mod redirect;
|
||||
mod tokens;
|
||||
|
||||
pub use redirect::handle_redirect as oauth_redirect_handler;
|
||||
pub use tokens::get_tokens as oauth_get_tokens_handler;
|
@ -0,0 +1,85 @@
|
||||
use rocket::{
|
||||
serde::{json::Json, Deserialize},
|
||||
State,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
http::HttpClient,
|
||||
oauth2::get_access_token,
|
||||
state::Config,
|
||||
types::{ErrResponse, ErrorKind, OkResponse, ResponseResult},
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(crate = "rocket::serde", rename_all = "camelCase")]
|
||||
pub enum ApplicationType {
|
||||
Desktop,
|
||||
Web,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(crate = "rocket::serde", rename_all = "camelCase")]
|
||||
pub struct OAuthState {
|
||||
provider: String,
|
||||
application: ApplicationType,
|
||||
}
|
||||
|
||||
impl OAuthState {
|
||||
fn provider(&self) -> &str {
|
||||
&self.provider
|
||||
}
|
||||
|
||||
fn application_type(&self) -> &ApplicationType {
|
||||
&self.application
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/redirect?<code>&<state>&<scope>&<error>")]
|
||||
pub async fn handle_redirect(
|
||||
code: Option<String>,
|
||||
state: Json<OAuthState>,
|
||||
scope: Option<String>,
|
||||
error: Option<String>,
|
||||
config: &State<Config>,
|
||||
http_client: &State<HttpClient>,
|
||||
) -> ResponseResult<String> {
|
||||
if code.is_some() && scope.is_some() {
|
||||
let provider = match config.oauth2().providers().get(state.provider()) {
|
||||
Some(provider) => provider,
|
||||
None => {
|
||||
return Err(ErrResponse::new(
|
||||
ErrorKind::BadRequest,
|
||||
"Could not find requested oauth provider",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let redirect_uri = format!("{}/mail/oauth2/redirect", config.external_host());
|
||||
let token_url = provider.token_url();
|
||||
let secret_token = provider.secret_token();
|
||||
let public_token = provider.public_token();
|
||||
let code = code.unwrap();
|
||||
|
||||
let access_token_response = get_access_token(
|
||||
&http_client,
|
||||
token_url,
|
||||
code.as_str(),
|
||||
redirect_uri.as_str(),
|
||||
public_token,
|
||||
secret_token,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| ErrResponse::from(err).into())?;
|
||||
|
||||
println!("{}", access_token_response.access_token());
|
||||
|
||||
Ok(OkResponse::new(token_url.to_string()))
|
||||
} else if error.is_some() {
|
||||
Err(ErrResponse::new(ErrorKind::BadRequest, "yeet"))
|
||||
} else {
|
||||
Err(ErrResponse::new(
|
||||
ErrorKind::BadRequest,
|
||||
"Missing required params",
|
||||
))
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rocket::State;
|
||||
|
||||
use crate::{
|
||||
state::Config,
|
||||
types::{OkResponse, ResponseResult},
|
||||
};
|
||||
|
||||
#[get("/tokens")]
|
||||
pub fn get_tokens(config: &State<Config>) -> ResponseResult<HashMap<String, String>> {
|
||||
let public_tokens: HashMap<String, String> = config
|
||||
.oauth2()
|
||||
.providers()
|
||||
.iter()
|
||||
.map(|(key, value)| return (key.to_string(), value.public_token().to_string()))
|
||||
.collect();
|
||||
|
||||
Ok(OkResponse::new(public_tokens))
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rocket::serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct OAuth2 {
|
||||
providers: HashMap<String, Provider>,
|
||||
}
|
||||
|
||||
impl OAuth2 {
|
||||
pub fn providers(&self) -> &HashMap<String, Provider> {
|
||||
&self.providers
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OAuth2 {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
providers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(crate = "rocket::serde")]
|
||||
pub struct Provider {
|
||||
public_token: String,
|
||||
secret_token: Option<String>,
|
||||
token_url: String,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn public_token(&self) -> &str {
|
||||
&self.public_token
|
||||
}
|
||||
|
||||
pub fn secret_token(&self) -> &Option<String> {
|
||||
&self.secret_token
|
||||
}
|
||||
|
||||
pub fn token_url(&self) -> &str {
|
||||
&self.token_url
|
||||
}
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
import { AuthType, ConnectionSecurity } from "@dust-mail/structures";
|
||||
|
||||
export default interface MultiServerLoginOptions {
|
||||
username: string;
|
||||
password: string;
|
||||
domain: string;
|
||||
port: number;
|
||||
security: ConnectionSecurity;
|
||||
loginType: AuthType[];
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
import { Result } from "./result";
|
||||
|
||||
export default interface OAuth2Client {
|
||||
getGrant: (
|
||||
providerName: string,
|
||||
grantUrl: string,
|
||||
tokenUrl: string,
|
||||
scopes: string[]
|
||||
) => Promise<Result<string>>;
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
import { createBaseError } from "./parseError";
|
||||
|
||||
import { ErrorResult } from "@interfaces/result";
|
||||
|
||||
export const NotLoggedIn = (): ErrorResult =>
|
||||
createBaseError({
|
||||
kind: "NotLoggedIn",
|
||||
message: "Could not find session token in local storage"
|
||||
});
|
||||
|
||||
export const NotImplemented = (feature?: string): ErrorResult =>
|
||||
createBaseError({
|
||||
kind: "NotImplemented",
|
||||
message: `The feature ${
|
||||
feature ? `'${feature}'` : ""
|
||||
} is not yet implemented`
|
||||
});
|
||||
|
||||
export const MissingRequiredParam = (): ErrorResult =>
|
||||
createBaseError({
|
||||
kind: "MissingRequiredParam",
|
||||
message: "Missing a required parameter"
|
||||
});
|
@ -0,0 +1,143 @@
|
||||
import z from "zod";
|
||||
|
||||
import useFetchClient from "./useFetchClient";
|
||||
import useSettings from "./useSettings";
|
||||
|
||||
import { OAuthState } from "@dust-mail/structures";
|
||||
|
||||
import OAuth2Client from "@interfaces/oauth2";
|
||||
import { Result } from "@interfaces/result";
|
||||
|
||||
import { NotImplemented } from "@utils/defaultErrors";
|
||||
import { createBaseError, createResultFromUnknown } from "@utils/parseError";
|
||||
import parseZodOutput from "@utils/parseZodOutput";
|
||||
|
||||
const useGetPublicOAuthTokens = (): (() => Promise<
|
||||
Result<Record<string, string>>
|
||||
>) => {
|
||||
const fetch = useFetchClient();
|
||||
|
||||
return () =>
|
||||
fetch("/mail/oauth2/tokens")
|
||||
.then((response) => {
|
||||
if (!response.ok) {
|
||||
return response;
|
||||
}
|
||||
|
||||
const output = z
|
||||
.record(z.string(), z.string())
|
||||
.safeParse(response.data);
|
||||
|
||||
return parseZodOutput(output);
|
||||
})
|
||||
.catch(createResultFromUnknown);
|
||||
};
|
||||
|
||||
const findProviderToken = (
|
||||
providerName: string,
|
||||
tokens: Record<string, string>
|
||||
): [string, string] | null => {
|
||||
for (const [key, value] of Object.entries(tokens)) {
|
||||
const isProvider = providerName
|
||||
.toLowerCase()
|
||||
.includes(key.trim().toLowerCase());
|
||||
|
||||
if (isProvider) return [value, key];
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const useOAuth2Client = (): OAuth2Client => {
|
||||
const isTauri: boolean = "__TAURI__" in window;
|
||||
|
||||
const getPublicTokens = useGetPublicOAuthTokens();
|
||||
|
||||
const [settings] = useSettings();
|
||||
|
||||
return {
|
||||
async getGrant(providerName, authUrl, tokenUrl, scopes) {
|
||||
const authUrlResult = z.string().url().safeParse(authUrl);
|
||||
|
||||
const authUrlOutput = parseZodOutput(authUrlResult);
|
||||
|
||||
if (!authUrlOutput.ok) {
|
||||
return authUrlOutput;
|
||||
}
|
||||
|
||||
if (settings.httpServerUrl === null)
|
||||
return createBaseError({
|
||||
kind: "NoBackend",
|
||||
message: "Backend server url is not set"
|
||||
});
|
||||
|
||||
const publicTokensResult = await getPublicTokens().catch(
|
||||
createResultFromUnknown
|
||||
);
|
||||
|
||||
if (!publicTokensResult.ok) {
|
||||
return publicTokensResult;
|
||||
}
|
||||
|
||||
const publicTokens = publicTokensResult.data;
|
||||
|
||||
const providerDetails = findProviderToken(providerName, publicTokens);
|
||||
|
||||
if (providerDetails === null)
|
||||
return createBaseError({
|
||||
kind: "NoOAuthToken",
|
||||
message:
|
||||
"Could not find a oauth token on remote Dust-Mail server to authorize with email provider"
|
||||
});
|
||||
|
||||
const providerToken = providerDetails[0];
|
||||
const providerId = providerDetails[1];
|
||||
|
||||
if (!isTauri) {
|
||||
if (typeof window !== "undefined" && "open" in window) {
|
||||
const url = new URL(authUrlOutput.data);
|
||||
const redirectUri = new URL(
|
||||
"/mail/oauth2/redirect",
|
||||
settings.httpServerUrl
|
||||
);
|
||||
|
||||
const state: OAuthState = {
|
||||
provider: providerId,
|
||||
application: isTauri ? "desktop" : "web"
|
||||
};
|
||||
|
||||
// https://www.rfc-editor.org/rfc/rfc6749#section-1.1
|
||||
url.searchParams.set("response_type", "code");
|
||||
url.searchParams.set("redirect_uri", redirectUri.toString());
|
||||
url.searchParams.set("client_id", providerToken);
|
||||
url.searchParams.set("scope", scopes.join(" "));
|
||||
url.searchParams.set("state", JSON.stringify(state));
|
||||
url.searchParams.set("access_type", "offline");
|
||||
|
||||
const oauthLoginWindow = window.open(url, "_blank", "popup");
|
||||
|
||||
if (oauthLoginWindow === null)
|
||||
return createBaseError({
|
||||
kind: "UnsupportedEnvironment",
|
||||
message:
|
||||
"Your browser environment does not support intercommunication between windows"
|
||||
});
|
||||
|
||||
oauthLoginWindow.addEventListener("message", console.log);
|
||||
|
||||
return { ok: true as const, data: "" };
|
||||
} else {
|
||||
return createBaseError({
|
||||
kind: "UnsupportedEnvironment",
|
||||
message:
|
||||
"Your browser environment does not support opening a new window"
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return NotImplemented("oauth-grant-tauri");
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
export default useOAuth2Client;
|