diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py index 30b565ba..4ca5c446 100644 --- a/browsergym/core/src/browsergym/core/env.py +++ b/browsergym/core/src/browsergym/core/env.py @@ -15,16 +15,10 @@ from .action.highlevel import HighLevelActionSet from .chat import Chat from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES -from .observation import ( - MarkingError, - _post_extract, - _pre_extract, - extract_dom_extra_properties, - extract_dom_snapshot, - extract_focused_element_bid, - extract_merged_axtree, - extract_screenshot, -) +from .observation import (MarkingError, _post_extract, _pre_extract, + extract_dom_extra_properties, extract_dom_snapshot, + extract_focused_element_bid, extract_merged_axtree, + extract_mouse_position, extract_screenshot) from .spaces import AnyBox, AnyDict, Float, Unicode from .task import AbstractBrowserTask @@ -157,6 +151,9 @@ def __init__( shape=(-1, -1, 3), dtype=np.uint8, ), # swapped axes (height, width, RGB) + "mouse_position": gym.spaces.Tuple( + (Float(), Float()) + ), "dom_object": AnyDict(), "axtree_object": AnyDict(), "extra_element_properties": AnyDict(), @@ -258,12 +255,17 @@ def override_property(task, env, property): # set default timeout self.context.set_default_timeout(timeout) - # hack: keep track of the active page with a javascript callback + # hack: keep track of the active page and mouse position with javascript callbacks # there is no concept of active page in playwright # https://github.com/microsoft/playwright/issues/2603 self.context.expose_binding( "browsergym_page_activated", lambda source: self._activate_page_from_js(source["page"]) ) + self.context.expose_binding( + "browsergym_mouse_moved", lambda source: self._update_mouse_position_from_js(source) + ) + # Initialize mouse position tracking + self.last_mouse_position = None self.context.add_init_script( r""" window.browsergym_page_activated(); @@ -271,7 +273,13 @@ def 
def _update_mouse_position_from_js(self, source, position=None):
    """
    Records the latest mouse position reported by a page's "mousemove" listener.

    Playwright invokes binding callbacks as `callback(source, *args)`: `source` only
    describes the caller (context / page / frame), while the object passed from
    JavaScript (`{x, y}`) arrives as a separate positional argument. This method
    therefore accepts that payload as `position`, so it can be registered directly
    as the binding callback.

    Args:
        source: binding caller info; must contain the originating "page".
        position: the `{x, y}` payload sent from JavaScript. When omitted, the
            coordinates are looked up in `source` instead (legacy single-dict
            call form).

    Raises:
        RuntimeError: if the event comes from a page of another browser context.
        KeyError: if no "x" / "y" coordinates were provided.
    """
    page = source["page"]
    # the JS payload is a separate argument; only fall back to `source` for old callers
    coords = source if position is None else position
    x, y = coords["x"], coords["y"]
    # lazy %-formatting: the page repr is only built when debug logging is enabled
    logger.debug("_update_mouse_position_from_js called, page=%s, x=%s, y=%s", page, x, y)

    if page.context != self.context:
        raise RuntimeError(
            f"Unexpected: mouse event from a page that belongs to a different browser context ({page})."
        )

    # Store the mouse position along with the page that received the event
    self.last_mouse_position = {
        "page": page,
        "x": x,
        "y": y,
        "timestamp": time.time(),
    }
def extract_mouse_position(page: "playwright.sync_api.Page"):
    """
    Extracts the last known mouse position on a Playwright page.

    Browsers offer no way to query the pointer location on demand (`pageX` / `pageY`
    are properties of mouse events, not of `window`), so the position has to be
    captured from events. The first call on a document installs a capture-phase
    "mousemove" listener that caches the event's viewport coordinates on `window`;
    every call returns the cached value.

    Args:
        page: the playwright page of which to extract the mouse position.

    Returns:
        A tuple (x, y) of floats holding the `clientX` / `clientY` of the last
        "mousemove" event seen by the listener, or (None, None) if no such event
        has been observed yet in the current document.
    """
    position = page.evaluate(
        """() => {
    const key = "__browsergym_mouse_position";
    if (!(key in window)) {
        window[key] = null;
        window.addEventListener("mousemove", (event) => {
            window[key] = [event.clientX, event.clientY];
        }, {capture: true});
    }
    return window[key];
}"""
    )
    # nothing recorded yet (listener just installed, or the mouse never moved)
    if position is None:
        return (None, None)
    x, y = position
    return (float(x), float(y))