impl FromRequest for ApiToken

and use it for the posts API
This commit is contained in:
Baptiste Gelez 2018-10-22 15:09:47 +01:00
parent 663ec52fea
commit 9a13d804c5
2 changed files with 68 additions and 8 deletions

View File

@ -1,6 +1,12 @@
use chrono::NaiveDateTime;
use diesel::{self, ExpressionMethods, QueryDsl, RunQueryDsl};
use rocket::{
Outcome,
http::Status,
request::{self, FromRequest, Request}
};
use db_conn::DbConn;
use schema::api_tokens;
#[derive(Clone, Queryable)]
@ -37,4 +43,46 @@ impl ApiToken {
get!(api_tokens);
insert!(api_tokens, NewApiToken);
find_by!(api_tokens, find_by_value, value as String);
fn can(&self, what: &'static str, scope: &'static str) -> bool {
let full_scope = what.to_owned() + ":" + scope;
for s in self.scopes.split('+') {
if s == what || s == full_scope {
return true
}
}
false
}
pub fn can_read(&self, what: &'static str) -> bool {
self.can("read", what)
}
pub fn can_write(&self, what: &'static str) -> bool {
self.can("write", what)
}
}
impl<'a, 'r> FromRequest<'a, 'r> for ApiToken {
type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<ApiToken, ()> {
let headers: Vec<_> = request.headers().get("Authorization").collect();
if headers.len() != 1 {
return Outcome::Failure((Status::BadRequest, ()));
}
let mut parsed_header = headers[0].split(' ');
let auth_type = parsed_header.next().expect("Expect a token type");
let val = parsed_header.next().expect("Expect a token value");
if auth_type == "Bearer" {
let conn = request.guard::<DbConn>().expect("Couldn't connect to DB");
if let Some(token) = ApiToken::find_by_value(&*conn, val.to_string()) {
return Outcome::Success(token);
}
}
return Outcome::Forward(());
}
}

View File

@ -7,19 +7,31 @@ use serde_qs;
use plume_api::posts::PostEndpoint;
use plume_models::{
Connection,
api_tokens::ApiToken,
db_conn::DbConn,
posts::Post,
};
#[get("/posts/<id>")]
fn get(id: i32, conn: DbConn) -> Json<serde_json::Value> {
fn get(id: i32, conn: DbConn, token: ApiToken) -> Json<serde_json::Value> {
if token.can_read("posts") {
let post = <Post as Provider<Connection>>::get(&*conn, id).ok();
Json(json!(post))
} else {
Json(json!({
"error": "Unauthorized"
}))
}
}
#[get("/posts")]
fn list(conn: DbConn, uri: &Origin) -> Json<serde_json::Value> {
fn list(conn: DbConn, uri: &Origin, token: ApiToken) -> Json<serde_json::Value> {
if token.can_read("posts") {
let query: PostEndpoint = serde_qs::from_str(uri.query().unwrap_or("")).expect("api::list: invalid query error");
let post = <Post as Provider<Connection>>::list(&*conn, query);
Json(json!(post))
}
} else {
Json(json!({
"error": "Unauthorized"
}))
}}