FastAPI Role-Based Access Control With JWT
FastAPI is a modern, high-performance web framework for building APIs with Python 3.8+, and one of the fastest Python frameworks available. Authentication and authorization are essential parts of almost any API. In this article, let's implement role-based access control (RBAC) with JWT in FastAPI.
Prerequisites
- Python programming knowledge
- Basic knowledge about FastAPI
Before you start, install the following Python packages. The code uses the str | None union syntax, so you need Python 3.10 or newer.
- fastapi
- pydantic
- uvicorn[standard]
- passlib[bcrypt]
- python-jose[cryptography]
- python-multipart (needed by OAuth2PasswordRequestForm to read form data)
Setting Up The Environment
Let’s create two API endpoints in the main.py file.
# main.py
from fastapi import FastAPI
app = FastAPI()
@app.get("/hello")
def hello_func():
    return "Hello World"
@app.get("/data")
def get_data():
    return {"data": "This is important data"}
Next, let's create a User model and a Token model in models.py.
# models.py
from pydantic import BaseModel
class User(BaseModel):
    username: str | None = None
    email: str | None = None
    role: str | None = None
    disabled: bool | None = None
    hashed_password: str | None = None
class Token(BaseModel):
    access_token: str | None = None
    refresh_token: str | None = None
For this tutorial, I will store dummy users in a Python dictionary in data.py, keyed by username, and keep issued refresh tokens in a list. In a real application, you would use a database such as PostgreSQL or MongoDB instead.
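# data.py
from passlib.context import CryptContext
# Hash the demo password at import time so the stored values are real bcrypt hashes.
# "secret" is a placeholder password for this tutorial only.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Keyed by username, so get_user() can look users up with `username in db`.
fake_users_db = {
    "johndoe": {
        "username": "johndoe",
        "email": "john@example.com",
        "role": "admin",
        "hashed_password": pwd_context.hash("secret"),
        "disabled": False,
    },
    "alice": {
        "username": "alice",
        "email": "alice@example.com",
        "role": "user",
        "hashed_password": pwd_context.hash("secret"),
        "disabled": False,
    },
}
# Refresh tokens that are currently valid (issued and not yet rotated out).
refresh_tokens = []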
How This Works
Now we are ready to add authentication to our API. The process works like this. The user logs in by sending a username and password in a POST request. The backend validates the credentials and, for a valid user, responds with two tokens: a short-lived access token and a long-lived refresh token. To call a protected endpoint, the client attaches the access token to the request in the Authorization header. When the access token expires, the client sends the refresh token to the backend to get a new access token.
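To make the client side of this flow concrete, here is a minimal sketch using only Python's standard library. It builds, but does not send, the two token-bearing requests described above. The base URL http://127.0.0.1:8000 (uvicorn's default), the helper name bearer_request, and the placeholder token strings are assumptions for illustration.

```python
from urllib.request import Request

BASE_URL = "http://127.0.0.1:8000"  # assumption: uvicorn's default host and port


def bearer_request(path: str, token: str, method: str = "GET") -> Request:
    """Hypothetical helper: build (not send) a request carrying a Bearer token."""
    return Request(
        BASE_URL + path,
        method=method,
        headers={"Authorization": f"Bearer {token}"},
    )


# Placeholder values; real tokens come from the login response.
access_token = "<access-token>"
refresh_token = "<refresh-token>"

# Call a protected endpoint with the short-lived access token.
data_request = bearer_request("/data", access_token)

# After the access token expires, send the refresh token to get a new pair.
refresh_request = bearer_request("/refresh", refresh_token, method="POST")
```

To actually send one of these requests, pass it to urllib.request.urlopen() while the server is running.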
Role Based Access Control (RBAC)
FastAPI provides several ways to deal with security. Here we use OAuth2 with the password flow (the FastAPI security documentation covers it in more detail), implemented with the OAuth2PasswordBearer class. We also use passlib's CryptContext to hash and verify passwords.
Let’s create auth.py. First, create instances of the above classes.
#auth.py
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
The tokenUrl parameter is the URL the client sends the username and password to in order to get a token. It only tells FastAPI (and the Swagger UI) where that endpoint lives; we haven't created it yet, but we will later.
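OAuth2PasswordBearer does not validate the token itself: used as a dependency, it reads the Authorization header and returns whatever follows the Bearer scheme, responding with 401 if the header is missing or uses another scheme. The standard-library sketch below mimics that parsing step so you can see what the dependency hands to your functions. The function extract_bearer_token and the MissingToken exception are hypothetical; this is not FastAPI's source code.

```python
class MissingToken(Exception):
    """Hypothetical stand-in for the 401 response FastAPI sends."""


def extract_bearer_token(headers: dict[str, str]) -> str:
    # HTTP header names are case-insensitive, so normalise them first.
    normalised = {name.lower(): value for name, value in headers.items()}
    authorization = normalised.get("authorization", "")
    # Split "Bearer <token>" into the scheme and the rest of the value.
    scheme, _, token = authorization.partition(" ")
    if not authorization or scheme.lower() != "bearer":
        raise MissingToken("Not authenticated")
    return token
```

The returned string is what get_current_user() receives as its token argument.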
Next, create a function that fetches a user from the database and another that authenticates a user by checking the password.
# auth.py
from passlib.context import CryptContext
from models import User
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user(db, username: str):
    # Look the user up by username; return None if not found.
    if username in db:
        user_dict = db[username]
        return User(**user_dict)
    return None
def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    # Compare the plain-text password with the stored bcrypt hash.
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
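pwd_context.hash() stores a salted, deliberately slow hash instead of the password, and pwd_context.verify() recomputes it to compare. The sketch below shows the same idea with the standard library's PBKDF2 (hashlib.pbkdf2_hmac) instead of bcrypt, so it runs without extra packages. This is a swapped-in technique for illustration only; keep using passlib with bcrypt in the app. The helper names hash_password, verify_password and authenticate, the iteration count, and the salt$hash storage format are all assumptions.

```python
import hashlib
import hmac
import os

ITERATIONS = 600_000  # assumption: a PBKDF2-SHA256 work factor chosen for illustration


def hash_password(password: str) -> str:
    # A fresh random salt per password means equal passwords get different hashes.
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, ITERATIONS)
    return salt.hex() + "$" + digest.hex()


def verify_password(password: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    candidate = hashlib.pbkdf2_hmac(
        "sha256", password.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    # Constant-time comparison avoids leaking how many characters matched.
    return hmac.compare_digest(candidate.hex(), digest_hex)


def authenticate(db: dict, username: str, password: str):
    # Same contract as authenticate_user(): the user on success, False otherwise.
    user = db.get(username)
    if not user or not verify_password(password, user["hashed_password"]):
        return False
    return user
```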
Now let's handle the JWT. Define a few settings and a function that creates a signed token.
# auth.py
from jose import JWTError, jwt
from datetime import datetime, timedelta, timezone
from data import refresh_tokens
# Demo value only: in production, use a long random secret (for example from
# openssl rand -hex 32) loaded from an environment variable.
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120
def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    # Fall back to a 15-minute lifetime when no expiry is given.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    # jwt.decode() checks the "exp" claim automatically.
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
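create_token() takes the claims to encode and an optional lifetime, adds an exp (expiration) claim, and returns the signed JWT as a string. If no lifetime is given, the token expires after 15 minutes.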
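Under the hood, an HS256 JWT is three base64url-encoded parts joined by dots: a header, the JSON claims (here sub, role and exp), and an HMAC-SHA256 signature over the first two. The standard-library sketch below builds and checks such a token by hand to show what jwt.encode() and jwt.decode() take care of, including rejecting expired tokens. The function names encode_hs256 and decode_hs256 are hypothetical, and the lifetime is given in seconds for simplicity; in the app, keep using python-jose.

```python
import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    # JWTs use URL-safe base64 without "=" padding.
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _b64url_decode(text: str) -> bytes:
    # Restore the padding that _b64url() stripped.
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def encode_hs256(claims: dict, secret: str, lifetime_seconds: int) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    payload = dict(claims, exp=int(time.time()) + lifetime_seconds)
    signing_input = (
        _b64url(json.dumps(header).encode()) + "." + _b64url(json.dumps(payload).encode())
    )
    signature = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + _b64url(signature)


def decode_hs256(token: str, secret: str) -> dict:
    header_b64, payload_b64, signature_b64 = token.split(".")
    signing_input = header_b64 + "." + payload_b64
    expected = hmac.new(secret.encode(), signing_input.encode(), hashlib.sha256).digest()
    # Reject tokens that were tampered with or signed with a different key.
    if not hmac.compare_digest(expected, _b64url_decode(signature_b64)):
        raise ValueError("invalid signature")
    payload = json.loads(_b64url_decode(payload_b64))
    # Reject tokens whose exp claim is in the past.
    if payload["exp"] <= time.time():
        raise ValueError("token expired")
    return payload
```

Note that the claims are only encoded, not encrypted: anyone holding the token can read sub and role, so never put secrets in them.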
Next, create a function that returns the currently logged-in user. It receives the token, decodes it, reads the username from the sub claim, and checks that the user exists in the database. If it does, the user is returned; otherwise, a 401 exception is raised.
# auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status
from jose import JWTError, jwt
from data import fake_users_db
from models import User
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
# oauth2_scheme and get_user() are the ones defined earlier in auth.py.
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Verify the signature and expiry, then read the username from "sub".
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    # Runs after get_current_user() and rejects disabled accounts.
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
We also need a second function that checks whether the user is disabled and raises an exception if so. Depends() declares a dependency: get_current_active_user() depends on get_current_user(), so FastAPI runs get_current_user() first and passes its result in as current_user. If you step through the code in a debugger, you will see get_current_user() execute before get_current_active_user().
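To see why get_current_user() runs first, here is a toy dependency resolver in plain Python. It is not FastAPI's implementation; it only mimics the idea that a parameter whose default is a Depends-like marker gets filled by calling that dependency (and its own dependencies) before the function itself runs. The Dep class and resolve() function are hypothetical, and the token is passed in directly rather than read from a request.

```python
import inspect

calls = []  # records the order in which the functions run


class Dep:
    """Hypothetical marker, standing in for fastapi.Depends."""

    def __init__(self, dependency):
        self.dependency = dependency


def resolve(func, **provided):
    # Fill each parameter: resolve Dep defaults first, otherwise use provided values.
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if isinstance(param.default, Dep):
            kwargs[name] = resolve(param.default.dependency, **provided)
        else:
            kwargs[name] = provided[name]
    return func(**kwargs)


def get_current_user(token):
    calls.append("get_current_user")
    # A real version would decode the token; this toy returns a fixed user.
    return {"username": "johndoe", "disabled": False, "token": token}


def get_current_active_user(current_user=Dep(get_current_user)):
    calls.append("get_current_active_user")
    if current_user["disabled"]:
        raise PermissionError("Inactive user")
    return current_user


user = resolve(get_current_active_user, token="abc")
```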
Now let's create a RoleChecker class to check user roles. If the user's role is in the list of allowed roles, it returns True; otherwise, it raises an exception.
# auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status
from models import User
# get_current_active_user() is the dependency defined above.
class RoleChecker:
    def __init__(self, allowed_roles):
        # Roles that may access the endpoint, e.g. ["admin"].
        self.allowed_roles = allowed_roles
    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="You don't have enough permissions")
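RoleChecker is a class rather than a plain function because Depends() needs a callable, and an instance with __call__ is a callable that remembers its configuration. The plain-Python sketch below isolates that idea without FastAPI: each instance carries its own allowed_roles, and calling it with a user performs the check. The Forbidden exception and the dict-based users are hypothetical stand-ins for HTTPException and the User model.

```python
class Forbidden(Exception):
    """Hypothetical stand-in for the HTTPException that RoleChecker raises."""


class RoleChecker:
    def __init__(self, allowed_roles: list[str]):
        # Stored once, when the dependency is declared on a route.
        self.allowed_roles = allowed_roles

    def __call__(self, user: dict) -> bool:
        # Called for every request with the already-authenticated user.
        if user["role"] in self.allowed_roles:
            return True
        raise Forbidden("You don't have enough permissions")


admin_only = RoleChecker(allowed_roles=["admin"])
staff = RoleChecker(allowed_roles=["admin", "user"])
```

Because each route gets its own instance, one class covers admin-only endpoints and endpoints open to several roles.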
We need one more function to validate the refresh token. When the access token expires, the client sends its refresh token to get a new access token.
# auth.py
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from pydantic import ValidationError
from data import fake_users_db, refresh_tokens
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# get_user() is the function defined earlier in auth.py.
async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Accept only refresh tokens that were issued and not yet rotated out.
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception
    except (JWTError, ValidationError):
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    # Return the old token too, so the endpoint can remove it from the store.
    return user, token
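The refresh_tokens list acts as an allow-list: a refresh token is accepted only while it is in the list, and the /refresh endpoint replaces it with a new one. This is usually called refresh-token rotation, and it means an old refresh token stops working once it has been used. The standard-library sketch below shows that bookkeeping on its own. The RefreshTokenStore class and the use of secrets.token_urlsafe() for opaque tokens are assumptions; the tutorial stores JWTs instead.

```python
import secrets


class RefreshTokenStore:
    """Hypothetical in-memory allow-list of currently valid refresh tokens."""

    def __init__(self):
        self._valid: set[str] = set()

    def issue(self) -> str:
        token = secrets.token_urlsafe(32)
        self._valid.add(token)
        return token

    def rotate(self, old_token: str) -> str:
        # Unknown or already-used tokens are rejected, like the
        # "token in refresh_tokens" check in validate_refresh_token().
        if old_token not in self._valid:
            raise PermissionError("Could not validate credentials")
        self._valid.remove(old_token)
        return self.issue()


store = RefreshTokenStore()
first = store.issue()
second = store.rotate(first)
```

A set gives constant-time membership checks, unlike the list in data.py; in production this state would live in a database or cache shared by all server processes.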
The final auth.py file looks like this.
# auth.py
from datetime import datetime, timedelta, timezone
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import ValidationError
from data import fake_users_db, refresh_tokens
from models import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Demo values only: in production, load a long random secret from the environment.
SECRET_KEY = "hdhfh5jdnb7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return User(**user_dict)
    return None
def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not pwd_context.verify(password, user.hashed_password):
        return False
    return user
def create_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
async def validate_refresh_token(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    try:
        # Only tokens still in the store (issued and not yet rotated) are accepted.
        if token in refresh_tokens:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
            username: str = payload.get("sub")
            role: str = payload.get("role")
            if username is None or role is None:
                raise credentials_exception
        else:
            raise credentials_exception
    except (JWTError, ValidationError):
        raise credentials_exception
    user = get_user(fake_users_db, username=username)
    if user is None:
        raise credentials_exception
    return user, token
class RoleChecker:
    def __init__(self, allowed_roles):
        self.allowed_roles = allowed_roles
    def __call__(self, user: Annotated[User, Depends(get_current_active_user)]):
        if user.role in self.allowed_roles:
            return True
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="You don't have enough permissions")
We now have the authentication and authorization pieces. Before protecting our endpoints, we need two more endpoints: one for logging in and one for refreshing tokens. Let's go back to main.py.
# main.py
from datetime import timedelta
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from auth import create_token, authenticate_user, RoleChecker, get_current_active_user, validate_refresh_token
from data import fake_users_db, refresh_tokens
from models import User, Token
app = FastAPI()
ACCESS_TOKEN_EXPIRE_MINUTES = 20
REFRESH_TOKEN_EXPIRE_MINUTES = 120
@app.get("/hello")
def hello_func():
    return "Hello World"
@app.get("/data")
def get_data():
    return {"data": "This is important data"}
@app.post("/token")
async def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]) -> Token:
    user = authenticate_user(fake_users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    # Remember the refresh token so that /refresh will accept it later.
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)
@app.post("/refresh")
async def refresh_access_token(token_data: Annotated[tuple[User, str], Depends(validate_refresh_token)]) -> Token:
    user, token = token_data
    # Define the lifetimes here too; they are local to login_for_access_token() above.
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(minutes=REFRESH_TOKEN_EXPIRE_MINUTES)
    access_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=access_token_expires)
    refresh_token = create_token(data={"sub": user.username, "role": user.role}, expires_delta=refresh_token_expires)
    # Rotate: the old refresh token can no longer be used.
    refresh_tokens.remove(token)
    refresh_tokens.append(refresh_token)
    return Token(access_token=access_token, refresh_token=refresh_token)
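OAuth2PasswordRequestForm reads the credentials from a form-encoded body (application/x-www-form-urlencoded) with fields named username and password, not from JSON, which is why the python-multipart package is needed. The standard-library sketch below builds, without sending, the login request a client would post to /token. The base URL and the demo credentials are assumptions.

```python
from urllib.parse import urlencode
from urllib.request import Request

BASE_URL = "http://127.0.0.1:8000"  # assumption: uvicorn's default host and port

# Demo credentials; the form fields must be named "username" and "password".
body = urlencode({"username": "johndoe", "password": "secret"}).encode()

login_request = Request(
    BASE_URL + "/token",
    data=body,
    method="POST",
    headers={"Content-Type": "application/x-www-form-urlencoded"},
)
```

The JSON response contains access_token and refresh_token; the client then sends the access token in an Authorization: Bearer header on later requests.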
Add RBAC To API
Now let's add RBAC to our endpoints. At the moment, the /data endpoint is not protected: anyone can call it. You can confirm this with the Swagger UI (served at /docs) or Postman. Let's restrict it to admin users.
@app.get("/data")
def get_data(_: Annotated[bool, Depends(RoleChecker(allowed_roles=["admin"]))]):
    return {"data": "This is important data"}
Now the endpoint can only be accessed by a logged-in admin user: requests without a valid access token are rejected, and so are authenticated users whose role is not in allowed_roles. You can add the same dependency to any endpoint you want to protect, passing the roles that should have access. This is one way to implement RBAC in FastAPI; other approaches exist, such as OAuth2 scopes, which FastAPI also supports.