-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
66 lines (56 loc) · 2.1 KB
/
app.py
File metadata and controls
66 lines (56 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import joblib
import re
from nltk.corpus import stopwords
import uvicorn
import nltk
from fastapi import FastAPI
from pydantic import BaseModel
# Download stopwords is now handled in the Dockerfile
# Build the English stopword set once at import time so clean_text can reuse it.
stopwords_set = set(stopwords.words('english'))

# Create the FastAPI app instance
app = FastAPI()

# Load the trained model and vectorizer once at startup; failing fast here is
# better than failing on the first /predict request.
try:
    model = joblib.load("models/resampled_logistic_regression.joblib")
    vectorizer = joblib.load("models/tfidf_vectorizer.joblib")
except FileNotFoundError as err:
    # Chain the original exception so the missing file path is visible
    # in the traceback instead of being silently discarded.
    raise RuntimeError("Model or vectorizer files not found. Please run the notebook to train and save them.") from err
# Define the input and output schemas
class Tweet(BaseModel):
    """Request body schema for /predict: the raw tweet text to classify."""
    text: str
# Define the cleaning function (copied from the notebook)
def clean_text(text):
    """Normalize raw tweet text for vectorization.

    Lowercases, strips mentions/URLs/non-word characters/digits, drops
    isolated single letters, collapses whitespace, and removes English
    stopwords (module-level ``stopwords_set``).
    """
    lowered = text.lower()
    # Mentions, URLs, non-word characters and digits all become spaces.
    stripped = re.sub(r'@[A-Za-z0-9_]+|https?://\S+|www\.\S+|\W+|\d+', ' ', lowered)
    # Drop single letters left dangling between spaces.
    stripped = re.sub(r'\s+[a-zA-Z]\s+', ' ', stripped)
    # Collapse any remaining runs of whitespace to one space.
    stripped = re.sub(r'\s+', ' ', stripped, flags=re.I)
    kept_words = (word for word in stripped.split() if word not in stopwords_set)
    return ' '.join(kept_words)
# Define the prediction function
def predict_class(text: str):
    """Clean and vectorize *text*, then return the model's class index as int."""
    # transform expects an iterable of documents, hence the one-element list.
    features = vectorizer.transform([clean_text(text)])
    predicted = model.predict(features)
    return int(predicted[0])
# Define the root endpoint
@app.get("/")
def read_root():
    """Root endpoint: return a static welcome message."""
    welcome = {"message": "Welcome to the Hate Speech Detection API!"}
    return welcome
# Define the health check endpoint
@app.get("/health")
def health_check():
    """Liveness probe: always reports an ok status."""
    status_payload = {"status": "ok"}
    return status_payload
# Define the prediction endpoint
@app.post("/predict")
def predict(tweet: Tweet):
    """Classify the submitted tweet and return its class index and label."""
    class_index = predict_class(tweet.text)
    # Human-readable names for the model's three output classes.
    labels = {0: "Hate Speech", 1: "Offensive Language", 2: "Neither"}
    return {
        "original_tweet": tweet.text,
        "predicted_class": class_index,
        "predicted_label": labels[class_index],
    }
# The main block to run the app
if __name__ == "__main__":
    # Serve the API directly with uvicorn when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=8000)