diff --git a/plume-models/src/admin.rs b/plume-models/src/admin.rs index a4fa0455..325198ea 100644 --- a/plume-models/src/admin.rs +++ b/plume-models/src/admin.rs @@ -1,38 +1,42 @@ use crate::users::User; use rocket::{ http::Status, - request::{self, FromRequest, Request}, + request::{self, FromRequestAsync, Request}, Outcome, }; /// Wrapper around User to use as a request guard on pages reserved to admins. pub struct Admin(pub User); -impl<'a, 'r> FromRequest<'a, 'r> for Admin { +impl<'a, 'r> FromRequestAsync<'a, 'r> for Admin { type Error = (); - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let user = request.guard::()?; - if user.is_admin() { - Outcome::Success(Admin(user)) - } else { - Outcome::Failure((Status::Unauthorized, ())) - } + fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> { + Box::pin(async move { + let user = try_outcome!(request.guard::()); + if user.is_admin() { + Outcome::Success(Admin(user)) + } else { + Outcome::Failure((Status::Unauthorized, ())) + } + }) } } /// Same as `Admin` but for moderators. pub struct Moderator(pub User); -impl<'a, 'r> FromRequest<'a, 'r> for Moderator { +impl<'a, 'r> FromRequestAsync<'a, 'r> for Moderator { type Error = (); - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let user = request.guard::()?; - if user.is_moderator() { - Outcome::Success(Moderator(user)) - } else { - Outcome::Failure((Status::Unauthorized, ())) - } + fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> { + Box::pin(async move { + let user = try_outcome!(request.guard::()); + if user.is_moderator() { + Outcome::Success(Moderator(user)) + } else { + Outcome::Failure((Status::Unauthorized, ())) + } + }) } } diff --git a/plume-models/src/api_tokens.rs b/plume-models/src/api_tokens.rs index 5e2eca07..5415111c 100644 --- a/plume-models/src/api_tokens.rs +++ b/plume-models/src/api_tokens.rs @@ -3,7 +3,7 @@ use chrono::NaiveDateTime; use diesel::{self, ExpressionMethods, QueryDsl, RunQueryDsl}; use rocket::{ http::Status, - request::{self, FromRequest, Request}, + request::{self, FromRequestAsync, Request}, Outcome, }; @@ -76,34 +76,37 @@ pub enum TokenError { DbError, } -impl<'a, 'r> FromRequest<'a, 'r> for ApiToken { +impl<'a, 'r> FromRequestAsync<'a, 'r> for ApiToken { type Error = TokenError; - fn from_request(request: &'a Request<'r>) -> request::Outcome { - let headers: Vec<_> = request.headers().get("Authorization").collect(); - if headers.len() != 1 { - return Outcome::Failure((Status::BadRequest, TokenError::NoHeader)); - } - - let mut parsed_header = headers[0].split(' '); - let auth_type = parsed_header.next().map_or_else( - || Outcome::Failure((Status::BadRequest, TokenError::NoType)), - Outcome::Success, - )?; - let val = parsed_header.next().map_or_else( - || Outcome::Failure((Status::BadRequest, TokenError::NoValue)), - Outcome::Success, - )?; - - if auth_type == "Bearer" { - let conn = request - .guard::() - .map_failure(|_| (Status::InternalServerError, TokenError::DbError))?; - if let Ok(token) = ApiToken::find_by_value(&*conn, val) { - return Outcome::Success(token); + fn from_request(request: &'a Request<'r>) -> request::FromRequestFuture<'a, Self, Self::Error> { + Box::pin(async move { + let headers: Vec<_> = request.headers().get("Authorization").collect(); + if headers.len() != 1 { + return Outcome::Failure((Status::BadRequest, TokenError::NoHeader)); } - } - Outcome::Forward(()) + let mut parsed_header = headers[0].split(' '); + if let Some(auth_type) = parsed_header.next() { + if let Some(val) = parsed_header.next() { + if auth_type == "Bearer" { + if let Outcome::Success(conn) = request.guard::() { + if let Ok(token) = ApiToken::find_by_value(&*conn, val) { + return Outcome::Success(token); + } + } else { + return Outcome::Failure((Status::InternalServerError, TokenError::DbError)); + } + } + } else { + return Outcome::Failure((Status::BadRequest, TokenError::NoValue)); + } + } else { + return Outcome::Failure((Status::BadRequest, TokenError::NoType)); + } + + + Outcome::Forward(()) + }) } }