diff --git a/plume-models/src/api_tokens.rs b/plume-models/src/api_tokens.rs index 6245bf43..b5a244da 100644 --- a/plume-models/src/api_tokens.rs +++ b/plume-models/src/api_tokens.rs @@ -1,6 +1,12 @@ use chrono::NaiveDateTime; use diesel::{self, ExpressionMethods, QueryDsl, RunQueryDsl}; +use rocket::{ + Outcome, + http::Status, + request::{self, FromRequest, Request} +}; +use db_conn::DbConn; use schema::api_tokens; #[derive(Clone, Queryable)] @@ -37,4 +43,46 @@ impl ApiToken { get!(api_tokens); insert!(api_tokens, NewApiToken); find_by!(api_tokens, find_by_value, value as String); + + fn can(&self, what: &'static str, scope: &'static str) -> bool { + let full_scope = what.to_owned() + ":" + scope; + for s in self.scopes.split('+') { + if s == what || s == full_scope { + return true + } + } + false + } + + pub fn can_read(&self, what: &'static str) -> bool { + self.can("read", what) + } + + pub fn can_write(&self, what: &'static str) -> bool { + self.can("write", what) + } +} + +impl<'a, 'r> FromRequest<'a, 'r> for ApiToken { + type Error = (); + + fn from_request(request: &'a Request<'r>) -> request::Outcome { + let headers: Vec<_> = request.headers().get("Authorization").collect(); + if headers.len() != 1 { + return Outcome::Failure((Status::BadRequest, ())); + } + + let mut parsed_header = headers[0].split(' '); + let auth_type = parsed_header.next().expect("Expect a token type"); + let val = parsed_header.next().expect("Expect a token value"); + + if auth_type == "Bearer" { + let conn = request.guard::().expect("Couldn't connect to DB"); + if let Some(token) = ApiToken::find_by_value(&*conn, val.to_string()) { + return Outcome::Success(token); + } + } + + return Outcome::Forward(()); + } } diff --git a/src/api/posts.rs b/src/api/posts.rs index 56dc657f..25ce45bf 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -7,19 +7,31 @@ use serde_qs; use plume_api::posts::PostEndpoint; use plume_models::{ Connection, + api_tokens::ApiToken, db_conn::DbConn, posts::Post, }; #[get("/posts/")] -fn get(id: i32, conn: DbConn) -> Json { - let post = >::get(&*conn, id).ok(); - Json(json!(post)) +fn get(id: i32, conn: DbConn, token: ApiToken) -> Json { + if token.can_read("posts") { + let post = >::get(&*conn, id).ok(); + Json(json!(post)) + } else { + Json(json!({ + "error": "Unauthorized" + })) + } } #[get("/posts")] -fn list(conn: DbConn, uri: &Origin) -> Json { - let query: PostEndpoint = serde_qs::from_str(uri.query().unwrap_or("")).expect("api::list: invalid query error"); - let post = >::list(&*conn, query); - Json(json!(post)) -} +fn list(conn: DbConn, uri: &Origin, token: ApiToken) -> Json { + if token.can_read("posts") { + let query: PostEndpoint = serde_qs::from_str(uri.query().unwrap_or("")).expect("api::list: invalid query error"); + let post = >::list(&*conn, query); + Json(json!(post)) + } else { + Json(json!({ + "error": "Unauthorized" + })) + }}