Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions agent_state.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import Annotated, Sequence, TypedDict, List
from langgraph.graph import MessagesState
from typing import Annotated, List, TypedDict

from langchain_core.messages import BaseMessage


class IndicatorAgentState(TypedDict):
"""State type for the Indicator Agent including messages, input data, and analysis result."""
kline_data: Annotated[dict, "OHLCV dictionary used for computing technical indicators"]

kline_data: Annotated[
dict, "OHLCV dictionary used for computing technical indicators"
]
time_frame: Annotated[str, "time period for k line data provided"]
stock_name: Annotated[dict, "stock name for prompt"]

# Indicator Agent Tools output values (explicitly added per indicator)
rsi: Annotated[List[float], "Relative Strength Index values"]
macd: Annotated[List[float], "MACD line values"]
Expand All @@ -18,23 +21,47 @@ class IndicatorAgentState(TypedDict):
stoch_d: Annotated[List[float], "Stochastic Oscillator %D values"]
roc: Annotated[List[float], "Rate of Change values"]
willr: Annotated[List[float], "Williams %R values"]
indicator_report: Annotated[str, "Final indicator agent summary report to be used by downstream agents"]

indicator_report: Annotated[
str, "Final indicator agent summary report to be used by downstream agents"
]

# Pattern Agent
pattern_image: Annotated[str, "Base64-encoded K-line chart for pattern recognition agent use"]
pattern_image_filename: Annotated[str, "Local file path to saved K-line chart image"]
pattern_image_description: Annotated[str, "Brief description of the generated K-line image"]
pattern_report: Annotated[str, "Final pattern agent summary report to be used by downstream agents"]
pattern_image: Annotated[
str, "Base64-encoded K-line chart for pattern recognition agent use"
]
pattern_image_filename: Annotated[
str, "Local file path to saved K-line chart image"
]
pattern_image_description: Annotated[
str, "Brief description of the generated K-line image"
]
pattern_report: Annotated[
str, "Final pattern agent summary report to be used by downstream agents"
]

# Trend Agent
trend_image: Annotated[str, "Base64-encoded trend-annotated candlestick (K-line) chart for trend recognition agent use"]
trend_image_filename: Annotated[str, "Local file path to saved trendline-enhanced K-line chart image"]
trend_image_description: Annotated[str, "Brief description of the chart, including presence of support/resistance lines and visual characteristics"]
trend_report: Annotated[str, "Final trend analysis summary, describing structure, directional bias, and technical observations for downstream agents"]
trend_image: Annotated[
str,
"Base64-encoded trend-annotated candlestick (K-line) chart for trend recognition agent use",
]
trend_image_filename: Annotated[
str, "Local file path to saved trendline-enhanced K-line chart image"
]
trend_image_description: Annotated[
str,
"Brief description of the chart, including presence of support/resistance lines and visual characteristics",
]
trend_report: Annotated[
str,
"Final trend analysis summary, describing structure, directional bias, and technical observations for downstream agents",
]

# Final analysis and messaging context
analysis_results: Annotated[str, "Computed result of the analysis or decision"]
messages: Annotated[List[BaseMessage], "List of chat messages used in LLM prompt construction"]
messages: Annotated[
List[BaseMessage], "List of chat messages used in LLM prompt construction"
]
decision_prompt: Annotated[str, "decision prompt for reflection"]
final_trade_decision: Annotated[str, "Final BUY or SELL decision made after analyzing indicators"]
final_trade_decision: Annotated[
str, "Final BUY or SELL decision made after analyzing indicators"
]
24 changes: 12 additions & 12 deletions color_style.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import mplfinance as mpf

font = {
'font.family': 'sans-serif',
'font.sans-serif': ['Helvetica Neue', 'Arial', 'DejaVu Sans'],
'font.weight': 'normal',
'font.size': 15
"font.family": "sans-serif",
"font.sans-serif": ["Helvetica Neue", "Arial", "DejaVu Sans"],
"font.weight": "normal",
"font.size": 15,
}

my_color_style = mpf.make_mpf_style(
marketcolors=mpf.make_marketcolors(
down="#A02128", # color for bullish candles
up="#006340", # color for bearish candles
edge='none', # use candle fill color for edge
wick='black', # color of the wicks
volume='in' # default volume coloring
down="#A02128", # color for bullish candles
up="#006340", # color for bearish candles
edge="none", # use candle fill color for edge
wick="black", # color of the wicks
volume="in", # default volume coloring
),
gridstyle='-',
facecolor='white', # background color
rc= font,
gridstyle="-",
facecolor="white", # background color
rc=font,
)
9 changes: 5 additions & 4 deletions decision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ def create_final_trade_decider(llm):
Create a trade decision agent node. The agent uses LLM to synthesize indicator, pattern, and trend reports
and outputs a final trade decision (LONG or SHORT) with justification and risk-reward ratio.
"""

def trade_decision_node(state) -> dict:
indicator_report = state["indicator_report"]
pattern_report = state['pattern_report']
trend_report = state['trend_report']
time_frame = state['time_frame']
stock_name = state['stock_name']
pattern_report = state["pattern_report"]
trend_report = state["trend_report"]
time_frame = state["time_frame"]
stock_name = state["stock_name"]

# --- System prompt for LLM ---
prompt = f"""You are a high-frequency quantitative trading (HFT) analyst operating on the current {time_frame} K-line chart for {stock_name}. Your task is to issue an **immediate execution order**: **LONG** or **SHORT**. ⚠️ HOLD is prohibited due to HFT constraints.
Expand Down
2 changes: 1 addition & 1 deletion default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
"agent_llm_temperature": 0.1,
"graph_llm_temperature": 0.1,
"api_key": "",
}
}
70 changes: 23 additions & 47 deletions graph_setup.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,15 @@
from typing import Dict, Any
from typing import Dict

from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from graph_util import TechnicalTools
import default_config
import os
from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt.chat_agent_executor import AgentState
from langgraph.prebuilt import create_react_agent
from langgraph.graph import END, StateGraph, START

from typing import TypedDict, List
import random
from IPython.display import Image, display
import pandas as pd
from graph_util import *
# from langchain_community.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from agent_state import IndicatorAgentState

from indicator_agent import *
from decision_agent import *
from pattern_agent import *
from trend_agent import *
from decision_agent import create_final_trade_decider
from graph_util import TechnicalTools
from indicator_agent import create_indicator_agent
from pattern_agent import create_pattern_agent
from trend_agent import create_trend_agent


class SetGraph:
Expand All @@ -43,30 +24,32 @@ def __init__(
self.graph_llm = graph_llm
self.toolkit = toolkit
self.tool_nodes = tool_nodes

def set_graph(self):
# Create analyst nodes
# Create analyst nodes
agent_nodes = {}
tool_nodes = {}
all_agents = ['indicator', 'pattern', 'trend']
all_agents = ["indicator", "pattern", "trend"]

# create nodes for indicator agent
agent_nodes['indicator'] = create_indicator_agent(self.graph_llm, self.toolkit)
tool_nodes['indicator'] = self.tool_nodes['indicator']
agent_nodes["indicator"] = create_indicator_agent(self.graph_llm, self.toolkit)
tool_nodes["indicator"] = self.tool_nodes["indicator"]

# create nodes for pattern agent
agent_nodes['pattern'] = create_pattern_agent(self.agent_llm, self.graph_llm, self.toolkit)
tool_nodes['pattern'] = self.tool_nodes['pattern']
agent_nodes["pattern"] = create_pattern_agent(
self.agent_llm, self.graph_llm, self.toolkit
)
tool_nodes["pattern"] = self.tool_nodes["pattern"]

# create nodes for trend agent
agent_nodes['trend'] = create_trend_agent(self.agent_llm, self.graph_llm, self.toolkit)
tool_nodes['trend'] = self.tool_nodes['trend']
agent_nodes["trend"] = create_trend_agent(
self.agent_llm, self.graph_llm, self.toolkit
)
tool_nodes["trend"] = self.tool_nodes["trend"]

# create nodes for decision agent
# decision_agent_node = create_final_trade_decider(self.agent_llm)
decision_agent_node = create_final_trade_decider(self.graph_llm)


# create graph
graph = StateGraph(IndicatorAgentState)

Expand All @@ -75,31 +58,24 @@ def set_graph(self):
graph.add_node(f"{agent_type.capitalize()} Agent", cur_node)
graph.add_node(f"{agent_type}_tools", tool_nodes[agent_type])


# add rest of the nodes
graph.add_node("Decision Maker", decision_agent_node)

# set start of graph
graph.add_edge(START, "Indicator Agent")


# add edges to graph
for i, agent_type in enumerate(all_agents):
current_agent = f"{agent_type.capitalize()} Agent"
current_tools = f"{agent_type}_tools"


if i == len(all_agents) - 1:
graph.add_edge(current_agent, "Decision Maker")
else:

next_agent = f"{all_agents[i + 1].capitalize()} Agent"
graph.add_edge(current_agent, next_agent)



# Decision Maker Process
graph.add_edge("Decision Maker", END)


return graph.compile()
Loading