|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import jwt # type: ignore[import-error] |
| 4 | +from fastapi import Depends, FastAPI, HTTPException, status |
| 5 | +from fastapi.security import OAuth2PasswordBearer |
| 6 | +from passlib.context import CryptContext |
| 7 | +from pydantic import BaseModel |
| 8 | + |
| 9 | +# auth secrets |
| 10 | +SECRET_KEY = "secret_key" |
| 11 | +ALGORITHM = "HS256" |
| 12 | +ACCESS_TOKEN_EXPIRE_MINUTES = 1 |
| 13 | + |
| 14 | +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") |
| 15 | +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
| 16 | + |
| 17 | +# fake db |
| 18 | +users_db = { |
| 19 | + "john_smith": { |
| 20 | + "username": "john_smith", |
| 21 | + "hashed_password": pwd_context.hash("john_smith"), |
| 22 | + "age": 25, |
| 23 | + "group": "users", |
| 24 | + }, |
| 25 | + "admin1": { |
| 26 | + "username": "admin1", |
| 27 | + "hashed_password": pwd_context.hash("john_smith"), |
| 28 | + "age": 25, |
| 29 | + "group": "admin", |
| 30 | + }, |
| 31 | +} |
| 32 | + |
| 33 | +app = FastAPI() |
| 34 | + |
| 35 | + |
| 36 | +class User(BaseModel): |
| 37 | + username: str |
| 38 | + email: str |
| 39 | + age: int |
| 40 | + role: str |
| 41 | + |
| 42 | + |
| 43 | +class TokenData(BaseModel): |
| 44 | + username: Optional[str] = None |
| 45 | + |
| 46 | + |
| 47 | +class ModelInferenceOutput(BaseModel): |
| 48 | + result: float |
| 49 | + |
| 50 | + |
| 51 | +async def get_current_user(token: str = Depends(oauth2_scheme)): |
| 52 | + credentials_exception = HTTPException( |
| 53 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 54 | + detail="Could not validate credentials", |
| 55 | + headers={"WWW-Authenticate": "Bearer"}, |
| 56 | + ) |
| 57 | + try: |
| 58 | + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) |
| 59 | + username: str = payload.get("sub") |
| 60 | + if username is None: |
| 61 | + raise credentials_exception |
| 62 | + token_data = TokenData(username=username) |
| 63 | + except jwt.PyJWTError: |
| 64 | + raise credentials_exception |
| 65 | + user = users_db.get(token_data.username, None) # type: ignore[arg-type] |
| 66 | + if user is None: |
| 67 | + raise credentials_exception |
| 68 | + return user |
| 69 | + |
| 70 | + |
| 71 | +@app.get("/") |
| 72 | +def index(): |
| 73 | + return {"text": "ML model inference"} |
| 74 | + |
| 75 | + |
| 76 | +@app.get("/analysis/{data}", response_model=ModelInferenceOutput) |
| 77 | +def run_model_analysis(data: str, user: User = Depends(get_current_user)): |
| 78 | + if user.role != "admin": |
| 79 | + raise HTTPException(status_code=403, detail="Operation not permitted.") |
| 80 | + result = sum(map(data.lower().count, "aeiuyo")) / len(data) |
| 81 | + return {"result": result} |
0 commit comments