diff --git a/dune-project b/dune-project index 5c5ac7ca2c9..e09a0911928 100644 --- a/dune-project +++ b/dune-project @@ -52,6 +52,13 @@ (name tgroup) (depends xapi-log xapi-stdext-unix)) +(package + (name rate-limit) + (synopsis "Simple token bucket-based rate-limiting") + (depends + (ocaml (>= 4.12)) + xapi-log xapi-stdext-unix)) + (package (name xml-light2)) diff --git a/ocaml/idl/datamodel.ml b/ocaml/idl/datamodel.ml index 983c15f9c77..5cb36121e44 100644 --- a/ocaml/idl/datamodel.ml +++ b/ocaml/idl/datamodel.ml @@ -10545,6 +10545,7 @@ let all_system = ; Datamodel_vm_group.t ; Datamodel_host_driver.t ; Datamodel_driver_variant.t + ; Datamodel_rate_limit.t ] (* If the relation is one-to-many, the "many" nodes (one edge each) must come before the "one" node (many edges) *) @@ -10796,6 +10797,7 @@ let expose_get_all_messages_for = ; _observer ; _host_driver ; _driver_variant + ; _rate_limit ] let no_task_id_for = [_task; (* _alert; *) _event] @@ -11152,6 +11154,10 @@ let http_actions = ; ("put_bundle", (Put, Constants.put_bundle_uri, true, [], _R_POOL_OP, [])) ] +(* Actions that incorporate the rate limiter from Xapi_rate_limiting within their handler + For now, just RPC calls *) +let custom_rate_limit_http_actions = ["post_root"; "post_RPC2"; "post_jsonrpc"] + (* these public http actions will NOT be checked by RBAC *) (* they are meant to be used in exceptional cases where RBAC is already *) (* checked inside them, such as in the XMLRPC (API) calls *) diff --git a/ocaml/idl/datamodel_common.ml b/ocaml/idl/datamodel_common.ml index bb8413396ee..f6d6c19f5b8 100644 --- a/ocaml/idl/datamodel_common.ml +++ b/ocaml/idl/datamodel_common.ml @@ -10,7 +10,7 @@ open Datamodel_roles to leave a gap for potential hotfixes needing to increment the schema version.*) let schema_major_vsn = 5 -let schema_minor_vsn = 793 +let schema_minor_vsn = 794 (* Historical schema versions just in case this is useful later *) let rio_schema_major_vsn = 5 @@ -315,6 +315,8 @@ let _host_driver = 
"Host_driver" let _driver_variant = "Driver_variant" +let _rate_limit = "Rate_limit" + let update_guidances = Enum ( "update_guidances" diff --git a/ocaml/idl/datamodel_lifecycle.ml b/ocaml/idl/datamodel_lifecycle.ml index 8cfdf21cef2..d59198f6941 100644 --- a/ocaml/idl/datamodel_lifecycle.ml +++ b/ocaml/idl/datamodel_lifecycle.ml @@ -1,4 +1,6 @@ let prototyped_of_class = function + | "Rate_limit" -> + Some "25.39.0" | "Driver_variant" -> Some "25.2.0" | "Host_driver" -> @@ -13,6 +15,16 @@ let prototyped_of_class = function None let prototyped_of_field = function + | "Rate_limit", "fill_rate" -> + Some "25.39.0" + | "Rate_limit", "burst_size" -> + Some "25.39.0" + | "Rate_limit", "host_ip" -> + Some "26.1.0" + | "Rate_limit", "user_agent" -> + Some "26.1.0" + | "Rate_limit", "uuid" -> + Some "25.39.0" | "Driver_variant", "status" -> Some "25.2.0" | "Driver_variant", "priority" -> @@ -138,11 +150,11 @@ let prototyped_of_field = function | "VM_guest_metrics", "netbios_name" -> Some "24.28.0" | "VM_metrics", "numa_node_memory" -> - Some "26.1.0-next" + Some "26.2.0" | "VM_metrics", "numa_nodes" -> - Some "26.1.0-next" + Some "26.2.0" | "VM_metrics", "numa_optimised" -> - Some "26.1.0-next" + Some "26.2.0" | "VM", "groups" -> Some "24.19.1" | "VM", "pending_guidances_full" -> @@ -197,6 +209,10 @@ let prototyped_of_field = function None let prototyped_of_message = function + | "Rate_limit", "destroy" -> + Some "26.1.0" + | "Rate_limit", "create" -> + Some "26.1.0" | "Driver_variant", "select" -> Some "25.2.0" | "Host_driver", "rescan" -> diff --git a/ocaml/idl/datamodel_rate_limit.ml b/ocaml/idl/datamodel_rate_limit.ml new file mode 100644 index 00000000000..e4feb257c81 --- /dev/null +++ b/ocaml/idl/datamodel_rate_limit.ml @@ -0,0 +1,69 @@ +(* + * Copyright (C) 2023 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software 
Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +open Datamodel_types +open Datamodel_common +open Datamodel_roles + +let lifecycle = [] + +let create = + call ~name:"create" ~lifecycle + ~params: + [ + ( String + , "user_agent" + , "User agent of the rate limited client. Set to the empty string to \ + rate limit all user agents." + ) + ; ( String + , "host_ip" + , "IP address of the rate limited client. Set to empty string to rate \ + limit all addresses." + ) + ; (Float, "burst_size", "Amount of tokens that can be consumed at once") + ; (Float, "fill_rate", "Amount of tokens added to the bucket every second") + ] + ~doc:"Create a new rate limiter for a given client" + ~allowed_roles:_R_POOL_OP () + ~result:(Ref _rate_limit, "The reference of the created rate limit.") + +let destroy = + call ~name:"destroy" ~lifecycle + ~params:[(Ref _rate_limit, "self", "The rate limiter to destroy")] + ~doc:"Destroy a rate limiter" ~allowed_roles:_R_POOL_OP () + +let t = + create_obj ~name:_rate_limit ~descr:"Rate limiting policy for a XAPI client" + ~doccomments:[] ~gen_constructor_destructor:false ~gen_events:true + ~in_db:true ~lifecycle ~persist:PersistEverything ~in_oss_since:None + ~messages_default_allowed_roles:_R_POOL_ADMIN + ~contents: + ([uid _rate_limit ~lifecycle] + @ [ + field ~qualifier:StaticRO ~ty:String ~lifecycle "user_agent" + "User agent of the rate limited client" ~ignore_foreign_key:true + ~default_value:(Some (VString "")) + ; field ~qualifier:StaticRO ~ty:String ~lifecycle "host_ip" + "IP address of the rate limited client" ~ignore_foreign_key:true + ~default_value:(Some (VString "")) + ; field ~qualifier:StaticRO ~ty:Float ~lifecycle 
"burst_size" + "Amount of tokens that can be consumed at once" + ~ignore_foreign_key:true ~default_value:(Some (VFloat 0.)) + ; field ~qualifier:StaticRO ~ty:Float ~lifecycle "fill_rate" + "Tokens added to token bucket per second" ~ignore_foreign_key:true + ~default_value:(Some (VFloat 0.)) + ] + ) + ~messages:[create; destroy] () diff --git a/ocaml/idl/dune b/ocaml/idl/dune index ac591ae1e0f..eb55c786d40 100644 --- a/ocaml/idl/dune +++ b/ocaml/idl/dune @@ -7,7 +7,7 @@ datamodel_values datamodel_schema datamodel_certificate datamodel_diagnostics datamodel_repository datamodel_lifecycle datamodel_vtpm datamodel_observer datamodel_vm_group api_version - datamodel_host_driver datamodel_driver_variant) + datamodel_host_driver datamodel_driver_variant datamodel_rate_limit) (libraries rpclib.core sexplib0 diff --git a/ocaml/idl/schematest.ml b/ocaml/idl/schematest.ml index 6bd5ee9ae36..3db2e237bf4 100644 --- a/ocaml/idl/schematest.ml +++ b/ocaml/idl/schematest.ml @@ -3,7 +3,7 @@ let hash x = Digest.string x |> Digest.to_hex (* BEWARE: if this changes, check that schema has been bumped accordingly in ocaml/idl/datamodel_common.ml, usually schema_minor_vsn *) -let last_known_schema_hash = "08510322cf77e8ba10082f2e611ebb40" +let last_known_schema_hash = "5eaddc1deda9c863deadef0d72b1b82e" let current_schema_hash : string = let open Datamodel_types in diff --git a/ocaml/libs/rate-limit/bucket_table.ml b/ocaml/libs/rate-limit/bucket_table.ml new file mode 100644 index 00000000000..fc2cb702423 --- /dev/null +++ b/ocaml/libs/rate-limit/bucket_table.ml @@ -0,0 +1,259 @@ +(* + * Copyright (C) 2025 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +module D = Debug.Make (struct let name = "bucket_table" end) + +type rate_limit_data = { + bucket: Token_bucket.t + ; process_queue: + (float * (unit -> unit)) Queue.t (* contains token cost and callback *) + ; process_queue_lock: Mutex.t + ; worker_thread_cond: Condition.t + ; should_terminate: bool ref (* signal termination to worker thread *) + ; worker_thread: Thread.t +} + +module Key = struct + type t = {user_agent: string; host_ip: string} + + let equal a b = a.user_agent = b.user_agent && a.host_ip = b.host_ip + + (** Empty string acts as wildcard, matching any value *) + let matches ~pattern ~target = + (pattern.user_agent = "" || pattern.user_agent = target.user_agent) + && (pattern.host_ip = "" || pattern.host_ip = target.host_ip) + + (** Priority for matching: exact (0) > host_ip only (1) > user_agent only (2) *) + let compare_wildcard k = + ( if k.user_agent = "" then + 2 + else + 0 + ) + + + if k.host_ip = "" then + 1 + else + 0 + + let is_all_wildcard k = k.user_agent = "" && k.host_ip = "" + + (** Total order: fewer wildcards first, then lexicographic by fields *) + let compare a b = + match compare (compare_wildcard a) (compare_wildcard b) with + | 0 -> ( + match String.compare a.user_agent b.user_agent with + | 0 -> + String.compare a.host_ip b.host_ip + | n -> + n + ) + | n -> + n +end + +type cached_table = { + table: (Key.t * rate_limit_data) list + ; cache: (Key.t, rate_limit_data option) Lru.t +} + +type t = cached_table Atomic.t + +let with_lock = Xapi_stdext_threads.Threadext.Mutex.execute + +let create () = Atomic.make {table= []; cache= Lru.create 100} + +(** Find the best matching entry for a client_id. 
+ List is pre-sorted by Key.compare (most specific first), so first match wins. + Priority: exact match > host_ip specified > user_agent specified *) +let find_match {table; cache} ~client_id = + let entry_opt = Lru.lookup cache client_id in + match entry_opt with + | Some result -> + result + | None -> + let result = + Option.map snd + (List.find_opt + (fun (key, _) -> Key.matches ~pattern:key ~target:client_id) + table + ) + in + if Lru.add cache client_id result then Lru.trim cache ; + result + +let mem t ~client_id = + let entries = Atomic.get t in + Option.is_some (find_match entries ~client_id) + +(* The worker thread is responsible for calling the callback when the token + amount becomes available *) +let rec worker_loop ~bucket ~process_queue ~process_queue_lock + ~worker_thread_cond ~should_terminate = + let process_item cost callback = + Token_bucket.delay_then_consume bucket cost ; + callback () + in + let item_opt = + with_lock process_queue_lock (fun () -> + while Queue.is_empty process_queue && not !should_terminate do + Condition.wait worker_thread_cond process_queue_lock + done ; + Queue.take_opt process_queue + ) + in + match item_opt with + | None -> + (* Queue is empty only when termination was signalled *) + () + | Some (cost, callback) -> + process_item cost callback ; + worker_loop ~bucket ~process_queue ~process_queue_lock ~worker_thread_cond + ~should_terminate + +(* TODO: Indicate failure reason - did we get invalid config or try to add an + already present client_id? 
*) +let add_bucket t ~client_id ~burst_size ~fill_rate = + if Key.is_all_wildcard client_id then + false + (* Reject keys with both fields empty *) + else + let {table; _} = Atomic.get t in + if List.exists (fun (key, _) -> Key.equal key client_id) table then + false + else + match Token_bucket.create ~burst_size ~fill_rate with + | Some bucket -> + let process_queue = Queue.create () in + let process_queue_lock = Mutex.create () in + let worker_thread_cond = Condition.create () in + let should_terminate = ref false in + let worker_thread = + Thread.create + (fun () -> + worker_loop ~bucket ~process_queue ~process_queue_lock + ~worker_thread_cond ~should_terminate + ) + () + in + let data = + { + bucket + ; process_queue + ; process_queue_lock + ; worker_thread_cond + ; should_terminate + ; worker_thread + } + in + Atomic.set t + { + table= + List.sort + (fun (k1, _) (k2, _) -> Key.compare k1 k2) + ((client_id, data) :: table) + ; cache= Lru.create 100 + } ; + true + | None -> + false + +let delete_bucket t ~client_id = + let {table; _} = Atomic.get t in + match List.find_opt (fun (key, _) -> Key.equal key client_id) table with + | None -> + () + | Some (_, data) -> + with_lock data.process_queue_lock (fun () -> + data.should_terminate := true ; + Condition.signal data.worker_thread_cond + ) ; + Thread.join data.worker_thread ; + Atomic.set t + { + table= + List.filter (fun (key, _) -> not (Key.equal key client_id)) table + ; cache= Lru.create 100 + } + +let try_consume t ~client_id amount = + let entries = Atomic.get t in + match find_match entries ~client_id with + | None -> + false + | Some data -> + Token_bucket.consume data.bucket amount + +let peek t ~client_id = + let entries = Atomic.get t in + Option.map + (fun data -> Token_bucket.peek data.bucket) + (find_match entries ~client_id) + +(* The callback should return quickly - if it is a longer task it is + responsible for creating a thread to do the task *) +let submit t ~client_id ~callback amount = + let 
entries = Atomic.get t in + match find_match entries ~client_id with + | None -> + D.debug "Found no rate limited client_id, returning" ; + callback () + | Some + ( {bucket; process_queue; process_queue_lock; worker_thread_cond; _} as + _data + ) -> + let run_immediately = + with_lock process_queue_lock (fun () -> + let immediate = + Queue.is_empty process_queue && Token_bucket.consume bucket amount + in + if not immediate then Queue.add (amount, callback) process_queue ; + Condition.signal worker_thread_cond ; + immediate + ) + in + if run_immediately then callback () + +(* Block and execute on the same thread *) +let submit_sync t ~client_id ~callback amount = + let entries = Atomic.get t in + match find_match entries ~client_id with + | None -> + callback () + | Some bucket_data -> ( + let channel_opt = + with_lock bucket_data.process_queue_lock (fun () -> + if + Queue.is_empty bucket_data.process_queue + && Token_bucket.consume bucket_data.bucket amount + then + None + (* Can run callback immediately after releasing lock *) + else + (* Rate limited, need to retrieve function result via channel *) + let channel = Event.new_channel () in + Queue.add + (amount, fun () -> Event.sync (Event.send channel ())) + bucket_data.process_queue ; + Condition.signal bucket_data.worker_thread_cond ; + Some channel + ) + in + match channel_opt with + | None -> + callback () + | Some channel -> + Event.sync (Event.receive channel) ; + callback () + ) diff --git a/ocaml/libs/rate-limit/bucket_table.mli b/ocaml/libs/rate-limit/bucket_table.mli new file mode 100644 index 00000000000..a3a810f3d05 --- /dev/null +++ b/ocaml/libs/rate-limit/bucket_table.mli @@ -0,0 +1,67 @@ +(* + * Copyright (C) 2025 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. 
with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +(** Key type for bucket table lookups. Empty strings act as wildcards. *) +module Key : sig + type t = {user_agent: string; host_ip: string} + + val equal : t -> t -> bool + + val matches : pattern:t -> target:t -> bool + (** [matches ~pattern ~target] returns true if [pattern] matches [target]. + Empty strings in [pattern] act as wildcards matching any value. *) + + val compare : t -> t -> int + (** Total order: fewer wildcards first, then lexicographic by fields. *) +end + +(** List of entries mapping keys to their token buckets for rate limiting. + Lookups use wildcard matching with priority: exact > host_ip only > user_agent only. *) +type t + +val create : unit -> t +(** [create ()] creates a new empty bucket table. *) + +val add_bucket : + t -> client_id:Key.t -> burst_size:float -> fill_rate:float -> bool +(** [add_bucket table ~client_id ~burst_size ~fill_rate] adds a token bucket + for the given client_id. Returns [false] if a bucket already exists, if + the bucket configuration is invalid (e.g. negative/zero fill rate), or if + client_id has both fields empty (all-wildcard keys are rejected). *) + +val mem : t -> client_id:Key.t -> bool +(** [mem table ~client_id] returns whether [client_id] matches any entry + in the bucket table using wildcard matching. *) + +val peek : t -> client_id:Key.t -> float option +(** [peek table ~client_id] returns the current token count for the client_id, + or [None] if no bucket exists. *) + +val delete_bucket : t -> client_id:Key.t -> unit +(** [delete_bucket table ~client_id] removes the bucket for the client_id. 
*) + +val try_consume : t -> client_id:Key.t -> float -> bool +(** [try_consume table ~client_id amount] attempts to consume tokens. + Returns [true] on success, [false] if insufficient tokens. *) + +val submit : t -> client_id:Key.t -> callback:(unit -> unit) -> float -> unit +(** [submit table ~client_id ~callback amount] submits a callback to be executed + under rate limiting. If tokens are immediately available and no callbacks are + queued, the callback runs synchronously. Otherwise, it is enqueued and will + be executed by a worker thread when tokens become available. Returns immediately. *) + +val submit_sync : t -> client_id:Key.t -> callback:(unit -> 'a) -> float -> 'a +(** [submit_sync table ~client_id ~callback amount] submits a callback to be + executed under rate limiting and blocks until it completes, returning the + callback's result. *) diff --git a/ocaml/libs/rate-limit/dune b/ocaml/libs/rate-limit/dune new file mode 100644 index 00000000000..3436c398228 --- /dev/null +++ b/ocaml/libs/rate-limit/dune @@ -0,0 +1,7 @@ +(library + (name rate_limit) + (public_name rate-limit) + (libraries threads.posix mtime mtime.clock.os xapi-log xapi-stdext-threads clock) +) + + diff --git a/ocaml/libs/rate-limit/lru.ml b/ocaml/libs/rate-limit/lru.ml new file mode 100644 index 00000000000..4df525143d0 --- /dev/null +++ b/ocaml/libs/rate-limit/lru.ml @@ -0,0 +1,288 @@ +(* + * This module implements a cache with support for a least-recently-used + * (LRU) replacement policy. Main features: + * + * Implemented with standard OCaml data types, no dependency on outside + * libararies for the main functionality. + * + * Implemented with mutable state. This keeps the implementation compact + * and efficient but requires thinking about state more. + * + * The architecture is: elements are kept in a hash table to look them + * up based on a key. Additionally they are kept in a doubly linked + * list. 
The head of the list is the least recently used element that + * can be dropped to make room in the cache. When an element is found in + * the cache it is moved to the tail of the linked list. + *) + +let invalid_arg fmt = Printf.ksprintf invalid_arg fmt + +let _fail fmt = Printf.ksprintf failwith fmt + +module LL : sig + (** Doubly linked list ['a t] holding elements of type ['a]. *) + + (** doubly linked list; this is a cyclic data structure; don't use [=] + on it as it may not terminate. *) + type 'a t + + (** a node in the list. A node can be removed from its list. Don't use + [=] on [node] values as it may not terminate. *) + type 'a node + + val create : unit -> 'a t + (** create an empty list *) + + val node : 'a -> 'a node + (** create a node to carry a value *) + + val value : 'a node -> 'a + (** obtain the value from a node *) + + val append : 'a t -> 'a node -> unit + (** append a node at the end *) + + val drop : 'a t -> 'a node -> unit + (** [drop t n] a node [n] from list [t]. 
It is an unchecked error to + pass a node [n] to [drop] that is not an element of [t] to begin + with.*) + + val first : 'a t -> 'a node option + (** first/head node of the list *) + + val last : 'a t -> 'a node option + (** last/tail node of the list *) + + val foldl : ('a -> 'b -> 'a) -> 'a -> 'b t -> 'a + (** fold from head *) + + val foldr : ('a -> 'b -> 'b) -> 'a t -> 'b -> 'b + (** fold from tail *) + + val to_list : 'a t -> 'a list + (** retrieve all elements from the list *) + + val from_list : 'a list -> 'a t + (** construct a [t] value from list *) +end = struct + type 'a node = { + value: 'a + ; mutable prev: 'a node option + ; mutable next: 'a node option + } + + type 'a t = {mutable first: 'a node option; mutable last: 'a node option} + + let create () = {first= None; last= None} + + let node x = {value= x; prev= None; next= None} + + let append t n = + match t.last with + | None -> + let node = Some n in + t.first <- node ; + t.last <- node + | Some lst -> + let node = Some n in + lst.next <- node ; + n.prev <- t.last ; + t.last <- node + + (** [drop] a node [n] from (its) list [t]. The interesting property is + that we can drop any element from its list that we know. However, + we don't check that [n] is indeed a member of [t] and it's an + unchecked error to pass an [n] that is not a member of [t]. + + This is similar to a + pointer-based implementation in C. We infer that we need to update + the fist, last entry of the list of [n]'s prev or next is [None], + hence it is the first or last element in the list. 
*) + let drop t n = + let np = n.prev in + let nn = n.next in + ( match np with + | None -> + t.first <- nn + | Some x -> + x.next <- nn ; + n.prev <- None + ) ; + match nn with + | None -> + t.last <- np + | Some x -> + x.prev <- np ; + n.next <- None + + let first t = t.first + + let last t = t.last + + let value node = node.value + + let foldl f zero t = + let rec loop acc = function + | None -> + acc + | Some n -> + loop (f acc n.value) n.next + in + loop zero t.first + + let foldr f t zero = + let rec loop acc = function + | None -> + acc + | Some n -> + loop (f n.value acc) n.prev + in + loop zero t.last + + let to_list t = foldr (fun x xs -> x :: xs) t [] + + let from_list xs = + let t = create () in + List.iter (fun x -> append t (node x)) xs ; + t +end + +(** A store for key/value pairs of type ['k] and ['v]. The main store + is [table] that maps a key to a key/value node. Every item in the + store is also a member in the [queue]. The [queue] keeps track of + which elements are looked up most often. The first element in the + [queue] is the least used one. *) +type ('k, 'v) t = { + table: ('k, ('k * 'v) LL.node) Hashtbl.t + ; queue: ('k * 'v) LL.t + ; cap: int (** max capacity of table and queue *) + ; mutable entries: int (** actual capacity of table and queue *) + ; lock: Mutex.t (** lock while operating on this value *) +} + +let locked m f = + let finally () = Mutex.unlock m in + Mutex.lock m ; Fun.protect ~finally f + +(* All primed functions below are not thread safe because they are + manipulating state; we will use a lock to protect against concurrent + update. However, we have to do that on an outer layer such that we + can use these functions internally after we obtained the lock. 
*) + +module Unsafe = struct + let create' capacity = + if capacity <= 0 then + invalid_arg "%s: capacity needs to be postive" __FUNCTION__ ; + { + table= Hashtbl.create capacity + ; queue= LL.create () + ; cap= capacity + ; entries= 0 + ; lock= Mutex.create () + } + + let size' t = t.entries + + let cap' t = t.cap + + let to_list' t = LL.to_list t.queue + + (** [lookup] an entry based on its [key]; this may fail or succeeed. + In the success case, the entry is moved to the tail of the + [queue]. Hnece, the least-used entry is at the front. *) + let lookup' t key = + match Hashtbl.find_opt t.table key with + | Some v -> + LL.drop t.queue v ; + LL.append t.queue v ; + Some (LL.value v |> snd) + | None -> + None + + (** [remove] an entry based on this [key] *) + let remove' t key = + match Hashtbl.find_opt t.table key with + | Some v -> + LL.drop t.queue v ; + Hashtbl.remove t.table key ; + t.entries <- t.entries - 1 + | None -> + () + + (** [add] a new entry; do nothing if the entry exists. If the new + entry exceeds the capacity of the table, [true] + is returned and [false] otherwise. It signals the caller to [trim] + the table.*) + let add' t key value = + match lookup' t key with + | None -> + let node = LL.node (key, value) in + Hashtbl.add t.table key node ; + t.entries <- t.entries + 1 ; + LL.append t.queue node ; + t.entries > t.cap + | Some _ -> + t.entries > t.cap + + (** [lru] returns the least-recently-used key/value pair *) + let lru' t = LL.first t.queue |> Option.map LL.value + + (** [drop_while] drops elements starting in least-recently-used order + while predicate [evict] is true. The predicate receives the key/value + and a boolean that indicates if the cache is over capacity. If + [evict] returns true it can perform any finalisation on the value + before it will be removed by [drop_while]. 
*) + + let rec drop_while' t ~evict = + match lru' t with + | Some ((key, _) as kv) when evict kv (t.entries > t.cap) -> + remove' t key ; drop_while' t ~evict + | Some _ -> + () + | None -> + () + + (** [trim] the table such that it does not exceed its capacity by + removing the least-used element repeatedly until this is achieved. + If finalisation of values is required, use [drop_while] because + [trim] does not provide it. *) + + let trim' t = + let evict _ x = x in + drop_while' t ~evict +end + +(* Functions below are intended to be used by clients of this modules. + They have to take care of locking. *) + +let create = Unsafe.create' + +let size t = locked t.lock @@ fun () -> Unsafe.size' t + +let cap t = locked t.lock @@ fun () -> Unsafe.cap' t + +let to_list t = locked t.lock @@ fun () -> Unsafe.to_list' t + +let lookup t = locked t.lock @@ fun () -> Unsafe.lookup' t + +let remove t = locked t.lock @@ fun () -> Unsafe.remove' t + +let add t = locked t.lock @@ fun () -> Unsafe.add' t + +let drop_while t = locked t.lock @@ fun () -> Unsafe.drop_while' t + +let trim t = locked t.lock @@ fun () -> Unsafe.trim' t + +(* + * Copyright (C) 2023 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. 
+ *) diff --git a/ocaml/libs/rate-limit/lru.mli b/ocaml/libs/rate-limit/lru.mli new file mode 100644 index 00000000000..3c8dc9def7b --- /dev/null +++ b/ocaml/libs/rate-limit/lru.mli @@ -0,0 +1,124 @@ +(* + * This module implements a cache with support for a least-recently-used + * (LRU) replacement policy. Main features: + * + * Implemented with standard OCaml data types, no dependency on outside + * libararies for the main functionality. + * + * Implemented with mutable state. This keeps the implementation compact + * and efficient but requires thinking about state more. + * + * The architecture is: elements are kept in a hash table to look them + * up based on a key. Additionally they are kept in a doubly linked + * list. The head of the list is the least recently used element that + * can be dropped to make room in the cache. When an element is found in + * the cache it is moved to the tail of the linked list. + *) + +(** A key/value store with a size cap and support to remove the + least-used element to make room for new entries. *) +type ('k, 'v) t + +val create : int -> ('a, 'b) t +(** [create] and empty [LRU] for a given (positive) size. *) + +val size : ('a, 'b) t -> int +(** current number of entries *) + +val cap : ('a, 'b) t -> int +(** max number of entries (from when LRU was created) *) + +val lookup : ('a, 'b) t -> 'a -> 'b option +(** [lookup] an entry. Returns [None] if it is not found. *) + +val remove : ('a, 'b) t -> 'a -> unit +(** [remove] an entry based on its key *) + +val add : ('a, 'b) t -> 'a -> 'b -> bool +(** [add] a new entry; if the entry already exists the entry is not + added. The reason is that we want to avoid removing entries without + clients being able to act on them. Returns [true] if the capacity of + the cache is exceeded and should be [trim]'ed. *) + +val drop_while : ('a, 'b) t -> evict:('a * 'b -> bool -> bool) -> unit +(** [drop_while] drops elements starting in least-recently-used + while predicate [evict] is true. 
The predicate receives the key/value + and a boolean that indicates if the cache is over capacity. If + [evict] returns true it can perform any finalisation on the value + before it will be removed by [drop_while]. [drop_while] can be used + to clean the cache or to remove elements on any criteria. + + The [evict] function must not call any of the functions of this API + but decide purely on the value it is passed *) + +val trim : ('a, 'b) t -> unit +(** [trim] the cache by removing least-used elements until the size does + not exceed the capacity. No finalisation of elements that are + removed. Use [drop_while] for custom finalisation. *) + +val to_list : ('a, 'b) t -> ('a * 'b) list +(** retrieve all elements as a list. The head of the list is in LRU + order: the least-used entry is at the head of the list. *) + +module LL : sig + (** Doubly linked list ['a t] holding elements of type ['a].I t is + only exposed here to facilitate testing. It should not be used by + outside code *) + + (** doubly linked list; this is a cyclic data structure; don't use [=] + on it as it may not terminate. *) + type 'a t + + (** a node in the list. A node can be removed from its list. Don't use + [=] on [node] values as it may not terminate. *) + type 'a node + + val create : unit -> 'a t + (** create an empty list *) + + val node : 'a -> 'a node + (** create a node to carry a value *) + + val value : 'a node -> 'a + (** obtain the value from a node *) + + val append : 'a t -> 'a node -> unit + (** append a node at the end *) + + val drop : 'a t -> 'a node -> unit + (** [drop t n] a node [n] from list [t]. 
It is an unchecked error to + pass a node [n] to [drop] that is not an element of [t] to begin + with.*) + + val first : 'a t -> 'a node option + (** first/head node of the list *) + + val last : 'a t -> 'a node option + (** last/tail node of the list *) + + val foldl : ('a -> 'b -> 'a) -> 'a -> 'b t -> 'a + (** fold from head *) + + val foldr : ('a -> 'b -> 'b) -> 'a t -> 'b -> 'b + (** fold from tail *) + + val to_list : 'a t -> 'a list + (** retrieve all elements from the list *) + + val from_list : 'a list -> 'a t + (** construct a [t] value from list *) +end + +(* + * Copyright (C) 2023 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. 
+ *) diff --git a/ocaml/libs/rate-limit/test/dune b/ocaml/libs/rate-limit/test/dune new file mode 100644 index 00000000000..49e9f54896a --- /dev/null +++ b/ocaml/libs/rate-limit/test/dune @@ -0,0 +1,4 @@ +(tests + (names test_token_bucket test_bucket_table test_lru) + (package rate-limit) + (libraries rate_limit alcotest qcheck-core qcheck-alcotest mtime mtime.clock.os fmt xapi-log threads.posix)) diff --git a/ocaml/libs/rate-limit/test/test_bucket_table.ml b/ocaml/libs/rate-limit/test/test_bucket_table.ml new file mode 100644 index 00000000000..3024a2976ae --- /dev/null +++ b/ocaml/libs/rate-limit/test/test_bucket_table.ml @@ -0,0 +1,626 @@ +open Rate_limit + +(* Helper to create a Key.t from a string for convenience in tests *) +let key s = Bucket_table.Key.{user_agent= s; host_ip= ""} + +let test_create () = + let table = Bucket_table.create () in + Alcotest.(check (option (float 0.0))) + "Empty table returns None for peek" None + (Bucket_table.peek table ~client_id:{user_agent= "test"; host_ip= ""}) + +let test_add_bucket () = + let table = Bucket_table.create () in + let success = + Bucket_table.add_bucket table + ~client_id:{user_agent= "test"; host_ip= ""} + ~burst_size:10.0 ~fill_rate:2.0 + in + Alcotest.(check bool) "Adding valid bucket should succeed" true success ; + Alcotest.(check (option (float 0.1))) + "Peek should return burst_size" (Some 10.0) + (Bucket_table.peek table ~client_id:{user_agent= "test"; host_ip= ""}) + +let test_add_bucket_invalid () = + let table = Bucket_table.create () in + let success = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:0.0 + in + Alcotest.(check bool) + "Adding bucket with zero fill rate should fail" false success ; + let success_neg = + Bucket_table.add_bucket table ~client_id:(key "agent2") ~burst_size:10.0 + ~fill_rate:(-1.0) + in + Alcotest.(check bool) + "Adding bucket with negative fill rate should fail" false success_neg + +let test_delete_bucket () = + let table = 
Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:2.0 + in + Alcotest.(check (option (float 0.1))) + "Bucket exists before delete" (Some 10.0) + (Bucket_table.peek table ~client_id:(key "agent1")) ; + Bucket_table.delete_bucket table ~client_id:(key "agent1") ; + Alcotest.(check (option (float 0.0))) + "Bucket removed after delete" None + (Bucket_table.peek table ~client_id:(key "agent1")) + +let test_delete_nonexistent () = + let table = Bucket_table.create () in + Bucket_table.delete_bucket table ~client_id:(key "nonexistent") ; + Alcotest.(check pass) "Deleting nonexistent bucket should not raise" () () + +let test_try_consume () = + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:2.0 + in + let success = Bucket_table.try_consume table ~client_id:(key "agent1") 3.0 in + Alcotest.(check bool) "Consuming available tokens should succeed" true success ; + Alcotest.(check (option (float 0.1))) + "Tokens reduced after consume" (Some 7.0) + (Bucket_table.peek table ~client_id:(key "agent1")) + +let test_try_consume_insufficient () = + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:5.0 + ~fill_rate:1.0 + in + let success = Bucket_table.try_consume table ~client_id:(key "agent1") 10.0 in + Alcotest.(check bool) + "Consuming more than available should fail" false success ; + Alcotest.(check (option (float 0.1))) + "Tokens unchanged after failed consume" (Some 5.0) + (Bucket_table.peek table ~client_id:(key "agent1")) + +let test_try_consume_nonexistent () = + let table = Bucket_table.create () in + let success = + Bucket_table.try_consume table ~client_id:(key "nonexistent") 1.0 + in + Alcotest.(check bool) + "Consuming from nonexistent bucket should fail" false success + +let test_peek_nonexistent () = + let table = Bucket_table.create () in 
+ Alcotest.(check (option (float 0.0))) + "Peek nonexistent bucket returns None" None + (Bucket_table.peek table ~client_id:(key "nonexistent")) + +let test_multiple_agents () = + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:2.0 + in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent2") ~burst_size:20.0 + ~fill_rate:5.0 + in + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 5.0 in + Alcotest.(check (option (float 0.1))) + "Agent1 tokens reduced" (Some 5.0) + (Bucket_table.peek table ~client_id:(key "agent1")) ; + Alcotest.(check (option (float 0.1))) + "Agent2 tokens unchanged" (Some 20.0) + (Bucket_table.peek table ~client_id:(key "agent2")) + +let test_submit () = + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:10.0 + in + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 10.0 in + let executed = ref false in + let start_counter = Mtime_clock.counter () in + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> executed := true) + 5.0 ; + let elapsed_span = Mtime_clock.count start_counter in + let elapsed_seconds = Mtime.Span.to_float_ns elapsed_span *. 
1e-9 in + (* submit should return immediately (non-blocking) *) + Alcotest.(check bool) "submit returns immediately" true (elapsed_seconds < 0.1) ; + (* Wait for callback to be executed by worker *) + Thread.delay 0.6 ; + Alcotest.(check bool) "callback eventually executed" true !executed + +let test_submit_nonexistent () = + let table = Bucket_table.create () in + let executed = ref false in + Bucket_table.submit table ~client_id:(key "nonexistent") + ~callback:(fun () -> executed := true) + 1.0 ; + Alcotest.(check bool) + "submit on nonexistent bucket runs callback immediately" true !executed + +let test_submit_fairness () = + (* Test that callbacks are executed in FIFO order regardless of token cost *) + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:5.0 + ~fill_rate:5.0 + in + (* Drain the bucket *) + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 5.0 in + let execution_order = ref [] in + let order_mutex = Mutex.create () in + let record_execution id = + Mutex.lock order_mutex ; + execution_order := id :: !execution_order ; + Mutex.unlock order_mutex + in + (* Submit callbacks with varying costs - order should be preserved *) + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 1) + 1.0 ; + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 2) + 3.0 ; + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 3) + 1.0 ; + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 4) + 2.0 ; + (* Wait for all callbacks to complete (total cost = 7 tokens, rate = 5/s) *) + Thread.delay 2.0 ; + let order = List.rev !execution_order in + Alcotest.(check (list int)) + "callbacks execute in FIFO order" [1; 2; 3; 4] order + +let test_submit_sync () = + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table 
~client_id:(key "agent1") ~burst_size:10.0 + ~fill_rate:10.0 + in + (* Test 1: Returns callback result immediately when tokens available *) + let result = + Bucket_table.submit_sync table ~client_id:(key "agent1") + ~callback:(fun () -> 42) + 5.0 + in + Alcotest.(check int) "returns callback result" 42 result ; + (* Test 2: Blocks and waits for tokens, then returns result *) + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 5.0 in + (* drain bucket *) + let start_counter = Mtime_clock.counter () in + let result2 = + Bucket_table.submit_sync table ~client_id:(key "agent1") + ~callback:(fun () -> "hello") + 5.0 + in + let elapsed_span = Mtime_clock.count start_counter in + let elapsed_seconds = Mtime.Span.to_float_ns elapsed_span *. 1e-9 in + Alcotest.(check string) "returns string result" "hello" result2 ; + Alcotest.(check bool) + "blocked waiting for tokens" true (elapsed_seconds >= 0.4) + +let test_submit_sync_nonexistent () = + let table = Bucket_table.create () in + let result = + Bucket_table.submit_sync table ~client_id:(key "nonexistent") + ~callback:(fun () -> 99) + 1.0 + in + Alcotest.(check int) + "submit_sync on nonexistent bucket runs callback immediately" 99 result + +let test_submit_sync_with_queued_items () = + (* Test that submit_sync respects FIFO ordering when queue has items *) + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:5.0 + ~fill_rate:10.0 + in + (* Drain the bucket *) + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 5.0 in + let execution_order = ref [] in + let order_mutex = Mutex.create () in + let record_execution id = + Mutex.lock order_mutex ; + execution_order := id :: !execution_order ; + Mutex.unlock order_mutex + in + (* Submit async items first *) + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 1) + 1.0 ; + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () 
-> record_execution 2) + 1.0 ; + (* Now submit_sync should queue behind the async items *) + let result = + Bucket_table.submit_sync table ~client_id:(key "agent1") + ~callback:(fun () -> record_execution 3 ; "sync_result") + 1.0 + in + Alcotest.(check string) + "submit_sync returns correct result" "sync_result" result ; + let order = List.rev !execution_order in + Alcotest.(check (list int)) + "submit_sync executes after queued items" [1; 2; 3] order + +let test_submit_sync_concurrent () = + (* Test multiple concurrent submit_sync calls *) + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:1.0 + ~fill_rate:10.0 + in + (* Drain the bucket to force queueing *) + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 1.0 in + let results = Array.make 5 0 in + let threads = + Array.init 5 (fun i -> + Thread.create + (fun () -> + let r = + Bucket_table.submit_sync table ~client_id:(key "agent1") + ~callback:(fun () -> i + 1) + 1.0 + in + results.(i) <- r + ) + () + ) + in + Array.iter Thread.join threads ; + (* Each thread should get its own result back *) + for i = 0 to 4 do + Alcotest.(check int) + (Printf.sprintf "thread %d gets correct result" i) + (i + 1) results.(i) + done + +let test_submit_sync_interleaved () = + (* Test interleaving submit and submit_sync *) + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "agent1") ~burst_size:2.0 + ~fill_rate:10.0 + in + (* Drain the bucket *) + let _ = Bucket_table.try_consume table ~client_id:(key "agent1") 2.0 in + let async_executed = ref false in + (* Submit async first *) + Bucket_table.submit table ~client_id:(key "agent1") + ~callback:(fun () -> async_executed := true) + 1.0 ; + (* Submit sync should wait for async to complete first *) + let sync_result = + Bucket_table.submit_sync table ~client_id:(key "agent1") + ~callback:(fun () -> !async_executed) + 1.0 + in + Alcotest.(check bool) + 
"sync callback sees async already executed" true sync_result + +let test_concurrent_add_delete_stress () = + (* Stress test: rapidly add and delete entries. + Without proper locking, hashtable can get corrupted. *) + let table = Bucket_table.create () in + let iterations = 1000 in + let num_keys = 10 in + let errors = ref 0 in + let errors_mutex = Mutex.create () in + let add_threads = + Array.init 5 (fun t -> + Thread.create + (fun () -> + for i = 0 to iterations - 1 do + let k = + Printf.sprintf "key%d" (((t * iterations) + i) mod num_keys) + in + let _ = + Bucket_table.add_bucket table ~client_id:(key k) + ~burst_size:10.0 ~fill_rate:1.0 + in + () + done + ) + () + ) + in + let delete_threads = + Array.init 5 (fun t -> + Thread.create + (fun () -> + for i = 0 to iterations - 1 do + let k = + Printf.sprintf "key%d" (((t * iterations) + i) mod num_keys) + in + Bucket_table.delete_bucket table ~client_id:(key k) + done + ) + () + ) + in + let read_threads = + Array.init 5 (fun t -> + Thread.create + (fun () -> + for i = 0 to iterations - 1 do + let k = + Printf.sprintf "key%d" (((t * iterations) + i) mod num_keys) + in + (* This should never crash, even if key doesn't exist *) + try + let _ = Bucket_table.peek table ~client_id:(key k) in + () + with _ -> + Mutex.lock errors_mutex ; + incr errors ; + Mutex.unlock errors_mutex + done + ) + () + ) + in + Array.iter Thread.join add_threads ; + Array.iter Thread.join delete_threads ; + Array.iter Thread.join read_threads ; + Alcotest.(check int) "No errors during concurrent operations" 0 !errors + +let test_consume_during_delete_race () = + (* Test that try_consume doesn't crash when bucket is being deleted. + Without proper locking, we could try to access a deleted bucket. 
*) + let iterations = 500 in + let errors = ref 0 in + let errors_mutex = Mutex.create () in + for _ = 1 to iterations do + let table = Bucket_table.create () in + let _ = + Bucket_table.add_bucket table ~client_id:(key "target") ~burst_size:100.0 + ~fill_rate:1.0 + in + let barrier = ref 0 in + let barrier_mutex = Mutex.create () in + let consumer = + Thread.create + (fun () -> + Mutex.lock barrier_mutex ; + incr barrier ; + Mutex.unlock barrier_mutex ; + while + Mutex.lock barrier_mutex ; + let b = !barrier in + Mutex.unlock barrier_mutex ; b < 2 + do + Thread.yield () + done ; + try + let _ = + Bucket_table.try_consume table ~client_id:(key "target") 1.0 + in + () + with _ -> + Mutex.lock errors_mutex ; incr errors ; Mutex.unlock errors_mutex + ) + () + in + let deleter = + Thread.create + (fun () -> + Mutex.lock barrier_mutex ; + incr barrier ; + Mutex.unlock barrier_mutex ; + while + Mutex.lock barrier_mutex ; + let b = !barrier in + Mutex.unlock barrier_mutex ; b < 2 + do + Thread.yield () + done ; + Bucket_table.delete_bucket table ~client_id:(key "target") + ) + () + in + Thread.join consumer ; Thread.join deleter + done ; + Alcotest.(check int) "No crashes during consume/delete race" 0 !errors + +(* Wildcard matching tests *) + +let test_wildcard_user_agent_matches_any () = + (* A bucket with empty user_agent field should match any user_agent header *) + let table = Bucket_table.create () in + let pattern = Bucket_table.Key.{user_agent= ""; host_ip= "192.168.1.1"} in + let _ = + Bucket_table.add_bucket table ~client_id:pattern ~burst_size:10.0 + ~fill_rate:1.0 + in + (* Should match any user_agent header value with same host_ip *) + let client1 = Bucket_table.Key.{user_agent= "curl"; host_ip= "192.168.1.1"} in + let client2 = Bucket_table.Key.{user_agent= "wget"; host_ip= "192.168.1.1"} in + let client3 = Bucket_table.Key.{user_agent= ""; host_ip= "192.168.1.1"} in + Alcotest.(check bool) + "wildcard user_agent matches curl" true + (Bucket_table.mem table 
~client_id:client1) ; + Alcotest.(check bool) + "wildcard user_agent matches wget" true + (Bucket_table.mem table ~client_id:client2) ; + Alcotest.(check bool) + "wildcard user_agent matches empty" true + (Bucket_table.mem table ~client_id:client3) ; + (* Should not match different host_ip *) + let client_other = + Bucket_table.Key.{user_agent= "curl"; host_ip= "10.0.0.1"} + in + Alcotest.(check bool) + "{user_agent=curl, host_ip=10.0.0.1} does not match {user_agent=*, \ + host_ip=192.168.1.1}" + false + (Bucket_table.mem table ~client_id:client_other) + +let test_wildcard_host_ip_matches_any () = + (* A bucket with empty host_ip should match any host_ip *) + let table = Bucket_table.create () in + let pattern = Bucket_table.Key.{user_agent= "curl"; host_ip= ""} in + let _ = + Bucket_table.add_bucket table ~client_id:pattern ~burst_size:10.0 + ~fill_rate:1.0 + in + (* Should match any host_ip with same user_agent header *) + let client1 = Bucket_table.Key.{user_agent= "curl"; host_ip= "192.168.1.1"} in + let client2 = Bucket_table.Key.{user_agent= "curl"; host_ip= "10.0.0.1"} in + let client3 = Bucket_table.Key.{user_agent= "curl"; host_ip= ""} in + Alcotest.(check bool) + "wildcard host_ip matches 192.168.1.1" true + (Bucket_table.mem table ~client_id:client1) ; + Alcotest.(check bool) + "wildcard host_ip matches 10.0.0.1" true + (Bucket_table.mem table ~client_id:client2) ; + Alcotest.(check bool) + "wildcard host_ip matches empty" true + (Bucket_table.mem table ~client_id:client3) ; + (* Should not match different user_agent header *) + let client_other = + Bucket_table.Key.{user_agent= "wget"; host_ip= "192.168.1.1"} + in + Alcotest.(check bool) + "wildcard does not match different user_agent" false + (Bucket_table.mem table ~client_id:client_other) + +let test_wildcard_match_priority_exact_first () = + (* Exact match should take priority over wildcards *) + let table = Bucket_table.create () in + let exact = Bucket_table.Key.{user_agent= "curl"; host_ip= 
"192.168.1.1"} in + let wildcard_ua = Bucket_table.Key.{user_agent= ""; host_ip= "192.168.1.1"} in + let wildcard_ip = Bucket_table.Key.{user_agent= "curl"; host_ip= ""} in + (* Add in reverse priority order to test sorting *) + let _ = + Bucket_table.add_bucket table ~client_id:wildcard_ua ~burst_size:5.0 + ~fill_rate:1.0 + in + let _ = + Bucket_table.add_bucket table ~client_id:wildcard_ip ~burst_size:15.0 + ~fill_rate:1.0 + in + let _ = + Bucket_table.add_bucket table ~client_id:exact ~burst_size:10.0 + ~fill_rate:1.0 + in + (* Lookup with exact key should return exact bucket (10.0), not wildcards *) + let client = Bucket_table.Key.{user_agent= "curl"; host_ip= "192.168.1.1"} in + Alcotest.(check (option (float 0.1))) + "exact match takes priority" (Some 10.0) + (Bucket_table.peek table ~client_id:client) + +let test_wildcard_match_priority_host_ip_over_user_agent () = + (* host_ip wildcard (user_agent specified) should match before + user_agent wildcard (host_ip specified) *) + let table = Bucket_table.create () in + let wildcard_ua = Bucket_table.Key.{user_agent= ""; host_ip= "192.168.1.1"} in + let wildcard_ip = Bucket_table.Key.{user_agent= "curl"; host_ip= ""} in + (* Add user_agent wildcard first *) + let _ = + Bucket_table.add_bucket table ~client_id:wildcard_ua ~burst_size:5.0 + ~fill_rate:1.0 + in + (* Add host_ip wildcard second *) + let _ = + Bucket_table.add_bucket table ~client_id:wildcard_ip ~burst_size:15.0 + ~fill_rate:1.0 + in + (* Lookup should prefer host_ip wildcard (15.0) over user_agent wildcard (5.0) *) + let client = Bucket_table.Key.{user_agent= "curl"; host_ip= "192.168.1.1"} in + Alcotest.(check (option (float 0.1))) + "host_ip wildcard takes priority over user_agent wildcard" (Some 15.0) + (Bucket_table.peek table ~client_id:client) + +let test_no_spurious_wildcard_matches () = + (* Ensure wildcards don't match when they shouldn't *) + let table = Bucket_table.create () in + let pattern1 = + Bucket_table.Key.{user_agent= "curl"; 
host_ip= "192.168.1.1"} + in + let pattern2 = Bucket_table.Key.{user_agent= "wget"; host_ip= ""} in + let _ = + Bucket_table.add_bucket table ~client_id:pattern1 ~burst_size:10.0 + ~fill_rate:1.0 + in + let _ = + Bucket_table.add_bucket table ~client_id:pattern2 ~burst_size:20.0 + ~fill_rate:1.0 + in + (* Client with different user_agent and host_ip should not match pattern1 *) + let client1 = Bucket_table.Key.{user_agent= "curl"; host_ip= "10.0.0.1"} in + Alcotest.(check bool) + "{user_agent=curl, host_ip=10.0.0.1} does not match {user_agent=curl, \ + host_ip=192.168.1.1}" + false + (Bucket_table.mem table ~client_id:client1) ; + (* Client with matching user_agent but different host_ip should match pattern2 *) + let client2 = Bucket_table.Key.{user_agent= "wget"; host_ip= "10.0.0.1"} in + Alcotest.(check (option (float 0.1))) + "{user_agent=wget, host_ip=10.0.0.1} matches {user_agent=wget, host_ip=*} \ + wildcard" + (Some 20.0) + (Bucket_table.peek table ~client_id:client2) ; + (* Client with no matching pattern *) + let client3 = + Bucket_table.Key.{user_agent= "firefox"; host_ip= "172.16.0.1"} + in + Alcotest.(check bool) + "{user_agent=firefox, host_ip=172.16.0.1} has no match" false + (Bucket_table.mem table ~client_id:client3) + +let test_reject_all_wildcard_key () = + (* Keys with both fields empty should be rejected *) + let table = Bucket_table.create () in + let all_wildcard = Bucket_table.Key.{user_agent= ""; host_ip= ""} in + let success = + Bucket_table.add_bucket table ~client_id:all_wildcard ~burst_size:10.0 + ~fill_rate:1.0 + in + Alcotest.(check bool) "all-wildcard key rejected" false success + +let test = + [ + ("Create empty table", `Quick, test_create) + ; ("Add valid bucket", `Quick, test_add_bucket) + ; ("Add invalid bucket", `Quick, test_add_bucket_invalid) + ; ("Delete bucket", `Quick, test_delete_bucket) + ; ("Delete nonexistent bucket", `Quick, test_delete_nonexistent) + ; ("Try consume", `Quick, test_try_consume) + ; ("Try consume 
insufficient", `Quick, test_try_consume_insufficient) + ; ("Try consume nonexistent", `Quick, test_try_consume_nonexistent) + ; ("Peek nonexistent", `Quick, test_peek_nonexistent) + ; ("Multiple agents", `Quick, test_multiple_agents) + ; ("Submit", `Slow, test_submit) + ; ("Submit nonexistent", `Quick, test_submit_nonexistent) + ; ("Submit fairness", `Slow, test_submit_fairness) + ; ("Submit sync", `Slow, test_submit_sync) + ; ("Submit sync interleaved", `Slow, test_submit_sync_interleaved) + ; ("Submit sync nonexistent", `Slow, test_submit_sync_nonexistent) + ; ("Submit sync concurrent", `Slow, test_submit_sync_concurrent) + ; ("Submit sync with queue", `Slow, test_submit_sync_with_queued_items) + ; ("Concurrent add/delete stress", `Quick, test_concurrent_add_delete_stress) + ; ("Consume during delete race", `Quick, test_consume_during_delete_race) + ; ( "Wildcard user_agent matches any" + , `Quick + , test_wildcard_user_agent_matches_any + ) + ; ("Wildcard host_ip matches any", `Quick, test_wildcard_host_ip_matches_any) + ; ( "Wildcard priority: exact first" + , `Quick + , test_wildcard_match_priority_exact_first + ) + ; ( "Wildcard priority: host_ip over user_agent" + , `Quick + , test_wildcard_match_priority_host_ip_over_user_agent + ) + ; ("No spurious wildcard matches", `Quick, test_no_spurious_wildcard_matches) + ; ("Reject all-wildcard key", `Quick, test_reject_all_wildcard_key) + ] + +let () = Alcotest.run "Bucket table library" [("Bucket table tests", test)] diff --git a/ocaml/libs/rate-limit/test/test_bucket_table.mli b/ocaml/libs/rate-limit/test/test_bucket_table.mli new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ocaml/libs/rate-limit/test/test_lru.ml b/ocaml/libs/rate-limit/test/test_lru.ml new file mode 100644 index 00000000000..d68ca369a36 --- /dev/null +++ b/ocaml/libs/rate-limit/test/test_lru.ml @@ -0,0 +1,120 @@ +(* + * Copyright (C) 2023 Cloud Software Group + * + * This program is free software; you can redistribute it and/or 
modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +open Rate_limit +module LRU = Lru +module LL = Lru.LL + +(* Generators *) +let chars = + let open QCheck in + list_of_size Gen.(int_range 0 50) char + +let ints = + let open QCheck in + list_of_size Gen.(int_range 0 50) int + +let kvs = + let open QCheck in + list_of_size Gen.(int_range 1 20) (pair char int) + +let lru = + let open QCheck in + kvs + |> map @@ fun kvs -> + let lru = LRU.create (List.length kvs) in + List.iter (fun (k, v) -> LRU.add lru k v |> ignore) kvs ; + lru + +(* Tests *) + +let count = 1000 + +let test_ll_from_to_list = + QCheck.Test.make ~name:"LL from_list/to_list roundtrip" chars ~count + @@ fun chars -> + let t = LL.from_list chars in + LL.to_list t = chars + +let test_ll_append_drop = + QCheck.Test.make ~name:"LL append and drop" chars ~count @@ fun chars -> + let open LL in + let t = from_list chars in + let x = node 'x' in + let y = node 'y' in + let z = node 'z' in + List.iter (append t) [x; y; z] ; + assert (match last t with Some z -> value z = 'z' | None -> false) ; + List.iter (drop t) [y; z; x] ; + to_list t = chars + +let test_ll_fold = + QCheck.Test.make ~name:"LL foldl/foldr consistency" ints ~count @@ fun ints -> + let total = List.fold_left ( + ) 0 ints in + let open LL in + let t = from_list ints in + List.for_all (( = ) total) [foldl ( + ) 0 t; foldr ( + ) t 0] + +let test_lru_length = + QCheck.Test.make ~name:"LRU length matches to_list" lru ~count @@ fun lru -> + LRU.to_list lru |> List.length = LRU.size lru + +let test_lru_drop = + QCheck.Test.make 
~name:"LRU drop_while evicts all" lru ~count @@ fun lru ->
+  let evict (_, _) _ = true in
+  LRU.drop_while ~evict lru ;
+  LRU.size lru = 0 && LRU.cap lru > 0
+
+(** add a new value but make room if the cache is full *)
+let add lru (key, value) =
+  (* [if]/[then] is the idiomatic form of dispatching on a boolean;
+     the original matched on [true]/[false]. Trimming only happens when
+     the insertion actually grew the cache. *)
+  if LRU.add lru key value then LRU.trim lru
+
+(** The test takes a full cache and adds more elements but now elements
+    are trimmed such that the cache does not grow *)
+let test_lru_growth =
+  QCheck.Test.make ~name:"LRU growth respects capacity"
+    QCheck.(pair lru kvs)
+    ~count
+  @@ fun (lru, kvs) ->
+  List.iter (add lru) kvs ;
+  LRU.size lru <= LRU.cap lru
+
+(* We expect to find all keys; we sort the keys before looking them up.
+   The least recently used key should be the head of that list *)
+let test_lru_lookup =
+  QCheck.Test.make ~name:"LRU lookup finds all keys" lru ~count @@ fun lru ->
+  let sort = List.sort_uniq Char.compare in
+  let keys = List.map fst (LRU.to_list lru) |> sort in
+  let lookup key =
+    match LRU.lookup lru key with
+    | Some _ ->
+        ()
+    | None ->
+        failwith "failed to find key"
+  in
+  List.iter lookup keys ;
+  LRU.to_list lru |> List.map fst |> sort = keys
+
+let test =
+  [
+    QCheck_alcotest.to_alcotest test_ll_from_to_list
+  ; QCheck_alcotest.to_alcotest test_ll_append_drop
+  ; QCheck_alcotest.to_alcotest test_ll_fold
+  ; QCheck_alcotest.to_alcotest test_lru_length
+  ; QCheck_alcotest.to_alcotest test_lru_drop
+  ; QCheck_alcotest.to_alcotest test_lru_growth
+  ; QCheck_alcotest.to_alcotest test_lru_lookup
+  ]
+
+let () = Alcotest.run "LRU library" [("LRU tests", test)]
diff --git a/ocaml/libs/rate-limit/test/test_lru.mli b/ocaml/libs/rate-limit/test/test_lru.mli
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/ocaml/libs/rate-limit/test/test_token_bucket.ml b/ocaml/libs/rate-limit/test/test_token_bucket.ml
new file mode 100644
index 00000000000..75038b36c7a
--- /dev/null
+++ b/ocaml/libs/rate-limit/test/test_token_bucket.ml
@@ -0,0 +1,409 @@
+open Thread
+open Rate_limit
+
+(* Constructors must reject non-positive fill rates. [Option.is_none]
+   avoids polymorphic equality ([= None]) on an option of the abstract
+   bucket type. *)
+let test_bad_fill_rate () =
+  let tb_zero = Token_bucket.create ~burst_size:1.0 ~fill_rate:0.0 in
+  Alcotest.(check bool)
+    "Creating a token bucket with 0 fill rate should fail" true
+    (Option.is_none tb_zero) ;
+  let tb_negative = Token_bucket.create ~burst_size:1.0 ~fill_rate:~-.1.0 in
+  Alcotest.(check bool)
+    "Creating a token bucket with negative fill rate should fail" true
+    (Option.is_none tb_negative)
+
+(* Consuming N tokens at a fixed (injected) timestamp removes exactly N. *)
+let test_consume_removes_correct_amount () =
+  let initial_time = Mtime.Span.of_uint64_ns 0L in
+  let tb =
+    Option.get
+      (Token_bucket.create_with_timestamp initial_time ~burst_size:10.0
+         ~fill_rate:2.0
+      )
+  in
+
+  Alcotest.(check (float 0.0))
+    "Initial tokens should be burst_size" 10.0
+    (Token_bucket.peek_with_timestamp initial_time tb) ;
+
+  let consume_time = Mtime.Span.of_uint64_ns 1_000_000_000L in
+  let success =
+    Token_bucket.consume_with_timestamp (fun () -> consume_time) tb 3.0
+  in
+  Alcotest.(check bool) "Consume 3 tokens should succeed" true success ;
+  Alcotest.(check (float 0.0))
+    "After consume, tokens should be 7" 7.0
+    (Token_bucket.peek_with_timestamp consume_time tb)
+
+(* A failed consume leaves the token count untouched. *)
+let test_consume_more_than_available () =
+  let initial_time = Mtime.Span.of_uint64_ns 0L in
+  let tb =
+    Option.get
+      (Token_bucket.create_with_timestamp initial_time ~burst_size:5.0
+         ~fill_rate:1.0
+      )
+  in
+
+  let _ = Token_bucket.consume_with_timestamp (fun () -> initial_time) tb 4.0 in
+
+  let consume_time = Mtime.Span.of_uint64_ns 1_000_000_000L in
+  let success =
+    Token_bucket.consume_with_timestamp (fun () -> consume_time) tb 10.0
+  in
+  Alcotest.(check bool) "Consume more than available should fail" false success ;
+  Alcotest.(check (float 0.0))
+    "After failed consume, tokens should be 2" 2.0
+    (Token_bucket.peek_with_timestamp consume_time tb)
+
+(* Refill is applied lazily at consume time, before tokens are removed. *)
+let test_consume_refills_before_removing () =
+  let initial_time = Mtime.Span.of_uint64_ns 0L in
+  let tb =
+    Option.get
+      (Token_bucket.create_with_timestamp initial_time ~burst_size:10.0
+         ~fill_rate:2.0
+
) in
+
+  let first_consume =
+    Token_bucket.consume_with_timestamp (fun () -> initial_time) tb 5.0
+  in
+  Alcotest.(check bool) "First consume should succeed" true first_consume ;
+
+  let later_time = Mtime.Span.of_uint64_ns 3_000_000_000L in
+  let second_consume =
+    Token_bucket.consume_with_timestamp (fun () -> later_time) tb 8.0
+  in
+
+  Alcotest.(check bool)
+    "Second consume after refill should succeed" true second_consume ;
+
+  Alcotest.(check (float 0.0))
+    "After refill and consume, tokens should be 2" 2.0
+    (Token_bucket.peek_with_timestamp later_time tb)
+
+(* Refill is capped at burst_size no matter how long the bucket sits idle. *)
+let test_peek_respects_burst_size () =
+  let initial_time = Mtime.Span.of_uint64_ns 0L in
+  let tb =
+    Option.get
+      (Token_bucket.create_with_timestamp initial_time ~burst_size:10.0
+         ~fill_rate:5.0
+      )
+  in
+
+  let _ = Token_bucket.consume_with_timestamp (fun () -> initial_time) tb 8.0 in
+
+  let later_time = Mtime.Span.of_uint64_ns 10_000_000_000L in
+  let available = Token_bucket.peek_with_timestamp later_time tb in
+  Alcotest.(check (float 0.0))
+    "Peek should respect burst_size limit" 10.0 available
+
+(* Ten threads each take one token from a 15-token bucket at a frozen
+   timestamp, so exactly 5 tokens must remain afterwards. *)
+let test_concurrent_access () =
+  let tb =
+    Option.get
+      (Token_bucket.create_with_timestamp Mtime.Span.zero ~burst_size:15.0
+         ~fill_rate:0.01
+      )
+  in
+  let threads =
+    Array.init 10 (fun _ ->
+        create
+          (fun () ->
+            Token_bucket.consume_with_timestamp
+              (fun () -> Mtime.Span.zero)
+              tb 1.0
+          )
+          ()
+    )
+  in
+  Array.iter Thread.join threads ;
+  (* Alcotest.check takes the expected value before the actual one; the
+     original passed them in the opposite order, which garbles the
+     failure message on a regression. *)
+  Alcotest.(check (float 0.0))
+    "Threads consuming concurrently should all remove from token amount" 5.0
+    (Token_bucket.peek_with_timestamp Mtime.Span.zero tb)
+
+(* Wall-clock variant: after sleeping 1s a 5-token/s bucket refills ~5. *)
+let test_sleep () =
+  let tb = Option.get (Token_bucket.create ~burst_size:20.0 ~fill_rate:5.0) in
+  let _ = Token_bucket.consume tb 10.0 in
+  Thread.delay 1.0 ;
+  Alcotest.(check (float 0.5))
+    "Sleep 1 should refill token bucket by fill_rate" 15.0 (Token_bucket.peek tb)
+
+(* Exercises the system-clock entry points (peek/consume without an
+   injected timestamp). *)
+let test_system_time_versions () =
+  let tb = Option.get (Token_bucket.create ~burst_size:10.0 ~fill_rate:2.0) in
+
+  
let initial_peek = Token_bucket.peek tb in + Alcotest.(check (float 0.01)) + "System time peek should return burst_size initially" 10.0 initial_peek ; + + let consume_result = Token_bucket.consume tb 3.0 in + Alcotest.(check bool) "System time consume should succeed" true consume_result ; + + let after_consume_peek = Token_bucket.peek tb in + Alcotest.(check (float 0.01)) + "After consume, should have 7 tokens" 7.0 after_consume_peek + +let test_concurrent_system_time () = + let tb = Option.get (Token_bucket.create ~burst_size:100.0 ~fill_rate:10.0) in + let num_threads = 20 in + let consume_per_thread = 3 in + + let threads = + Array.init num_threads (fun _ -> + create + (fun () -> + for _ = 1 to consume_per_thread do + ignore (Token_bucket.consume tb 1.0) + done + ) + () + ) + in + Array.iter Thread.join threads ; + + let remaining = Token_bucket.peek tb in + let expected_remaining = + 100.0 -. float_of_int (num_threads * consume_per_thread) + in + Alcotest.(check (float 0.1)) + "Concurrent system time consumption should work correctly" + expected_remaining remaining + +let test_consume_more_than_available_concurrent () = + let tb = + Option.get + (Token_bucket.create_with_timestamp Mtime.Span.zero ~burst_size:5.0 + ~fill_rate:0.1 + ) + in + let num_threads = 10 in + let consume_per_thread = 1 in + let successful_consumes = ref 0 in + let counter_mutex = Mutex.create () in + + let threads = + Array.init num_threads (fun _ -> + create + (fun () -> + let success = + Token_bucket.consume_with_timestamp + (fun () -> Mtime.Span.zero) + tb + (float_of_int consume_per_thread) + in + if success then ( + Mutex.lock counter_mutex ; + incr successful_consumes ; + Mutex.unlock counter_mutex + ) + ) + () + ) + in + Array.iter Thread.join threads ; + + Alcotest.(check int) + "Only 5 consumptions should succeed" 5 !successful_consumes ; + Alcotest.(check (float 0.1)) + "Bucket should be empty after consumptions" 0.0 + (Token_bucket.peek_with_timestamp Mtime.Span.zero tb) + +let 
test_delay_until_available () = + let initial_time = Mtime.Span.of_uint64_ns 0L in + let tb = + Option.get + (Token_bucket.create_with_timestamp initial_time ~burst_size:10.0 + ~fill_rate:2.0 + ) + in + + let _ = + Token_bucket.consume_with_timestamp (fun () -> initial_time) tb 10.0 + in + + let delay = + Token_bucket.get_delay_until_available_timestamp initial_time tb 4.0 + in + Alcotest.(check (float 0.01)) + "Delay for 4 tokens at 2 tokens/sec should be 2 seconds" 2.0 delay ; + + let tb_fresh = + Option.get (Token_bucket.create ~burst_size:10.0 ~fill_rate:2.0) + in + let _ = Token_bucket.consume tb_fresh 10.0 in + let delay_system = Token_bucket.get_delay_until_available tb_fresh 4.0 in + + Alcotest.(check (float 0.1)) + "System time delay should be approximately 2 seconds" 2.0 delay_system + +let test_edge_cases () = + let tb = + Option.get + (Token_bucket.create_with_timestamp Mtime.Span.zero ~burst_size:5.0 + ~fill_rate:1.0 + ) + in + let success = + Token_bucket.consume_with_timestamp (fun () -> Mtime.Span.zero) tb 0.0 + in + Alcotest.(check bool) "Consuming zero tokens should succeed" true success ; + + let tb_small = + Option.get + (Token_bucket.create_with_timestamp Mtime.Span.zero ~burst_size:1.0 + ~fill_rate:0.1 + ) + in + let success_small = + Token_bucket.consume_with_timestamp + (fun () -> Mtime.Span.zero) + tb_small 0.001 + in + Alcotest.(check bool) + "Consuming very small amount should succeed" true success_small + +let test_consume_quickcheck = + let open QCheck.Gen in + let gen_operations = + let gen_operation = + pair (float_range 0.0 1000.0) (int_range 0 1_000_000_000) + in + list_size (int_range 1 50) gen_operation + in + + let fail_peek op_num time_ns time_delta expected current added actual diff = + QCheck.Test.fail_reportf + "Operation %d: peek failed\n\ + \ Time: %d ns (delta: %d ns)\n\ + \ Expected tokens: %.3f (current: %.3f + added: %.3f)\n\ + \ Actual tokens: %.3f\n\ + \ Diff: %.6f" + op_num time_ns time_delta expected current added 
actual diff + in + + let fail_consume op_num time_ns time_delta amount available success expected + actual diff = + QCheck.Test.fail_reportf + "Operation %d: consume failed\n\ + \ Time: %d ns (delta: %d ns)\n\ + \ Consume amount: %.3f\n\ + \ Available before: %.3f\n\ + \ Success: %b\n\ + \ Expected after: %.3f\n\ + \ Actual after: %.3f\n\ + \ Diff: %.6f" + op_num time_ns time_delta amount available success expected actual diff + in + + let property (burst_size, fill_rate, operations) = + let initial_time = Mtime.Span.of_uint64_ns 0L in + let tb = + Option.get + (Token_bucket.create_with_timestamp initial_time ~burst_size ~fill_rate) + in + + let rec check_operations op_num time_ns last_refill_ns current_tokens ops = + match ops with + | [] -> + true + | (consume_amount, time_delta_ns) :: rest -> + let new_time_ns = time_ns + time_delta_ns in + let current_time = + Mtime.Span.of_uint64_ns (Int64.of_int new_time_ns) + in + let time_since_refill_seconds = + float_of_int (new_time_ns - last_refill_ns) *. 1e-9 + in + let tokens_added = time_since_refill_seconds *. fill_rate in + let expected_available = + min burst_size (current_tokens +. tokens_added) + in + let actual_before = + Token_bucket.peek_with_timestamp current_time tb + in + let peek_diff = abs_float (actual_before -. expected_available) in + + if peek_diff >= 0.001 then + fail_peek op_num new_time_ns time_delta_ns expected_available + current_tokens tokens_added actual_before peek_diff + else + let success = + Token_bucket.consume_with_timestamp + (fun () -> current_time) + tb consume_amount + in + let actual_after = + Token_bucket.peek_with_timestamp current_time tb + in + let new_tokens = + if success then + expected_available -. consume_amount + else + expected_available + in + let after_diff = abs_float (actual_after -. 
new_tokens) in + + if after_diff >= 0.001 then + fail_consume op_num new_time_ns time_delta_ns consume_amount + expected_available success new_tokens actual_after after_diff + else + check_operations (op_num + 1) new_time_ns new_time_ns new_tokens + rest + in + + check_operations 1 0 0 burst_size operations + in + + let gen_all = + map3 + (fun burst fill ops -> (burst, fill, ops)) + pfloat (float_range 1e-9 1e9) gen_operations + in + + let arb_all = + QCheck.make + ~print:(fun (burst, fill, ops) -> + let ops_str = + ops + |> List.mapi (fun i (amount, delta) -> + Printf.sprintf " Op %d: consume %.3f at +%d ns" (i + 1) amount + delta + ) + |> String.concat "\n" + in + Printf.sprintf "burst_size=%.3f, fill_rate=%.3f, %d operations:\n%s" + burst fill (List.length ops) ops_str + ) + gen_all + in + + QCheck.Test.make ~name:"Consume operations maintain correct token count" + ~count:100 arb_all (fun (burst, fill, ops) -> property (burst, fill, ops) + ) + +let test = + [ + ( "A bucket with zero or negative fill rate cannot be created" + , `Quick + , test_bad_fill_rate + ) + ; ( "Consume removes correct amount" + , `Quick + , test_consume_removes_correct_amount + ) + ; ("Consume more than available", `Quick, test_consume_more_than_available) + ; ( "Consume refills before removing" + , `Quick + , test_consume_refills_before_removing + ) + ; ("Peek respects burst size", `Quick, test_peek_respects_burst_size) + ; ("Concurrent access", `Quick, test_concurrent_access) + ; ("Refill after sleep", `Slow, test_sleep) + ; ("System time versions", `Quick, test_system_time_versions) + ; ("Concurrent system time", `Quick, test_concurrent_system_time) + ; ( "Consume more than available concurrent" + , `Quick + , test_consume_more_than_available_concurrent + ) + ; ("Delay until available", `Quick, test_delay_until_available) + ; ("Edge cases", `Quick, test_edge_cases) + ; QCheck_alcotest.to_alcotest test_consume_quickcheck + ] + +let () = Alcotest.run "Token bucket library" [("Token bucket 
tests", test)] diff --git a/ocaml/libs/rate-limit/test/test_token_bucket.mli b/ocaml/libs/rate-limit/test/test_token_bucket.mli new file mode 100644 index 00000000000..e69de29bb2d diff --git a/ocaml/libs/rate-limit/token_bucket.ml b/ocaml/libs/rate-limit/token_bucket.ml new file mode 100644 index 00000000000..d59683e02e5 --- /dev/null +++ b/ocaml/libs/rate-limit/token_bucket.ml @@ -0,0 +1,83 @@ +(* + * Copyright (C) 2025 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +type state = {tokens: float; last_refill: Mtime.span} + +type t = {burst_size: float; fill_rate: float; state: state Atomic.t} + +let create_with_timestamp timestamp ~burst_size ~fill_rate = + if fill_rate <= 0. then + None + else + let state = Atomic.make {tokens= burst_size; last_refill= timestamp} in + Some {burst_size; fill_rate; state} + +let create = create_with_timestamp (Mtime_clock.elapsed ()) + +let compute_tokens timestamp {tokens; last_refill} ~burst_size ~fill_rate = + let time_delta = Mtime.Span.abs_diff last_refill timestamp in + let time_delta_seconds = Mtime.Span.to_float_ns time_delta *. 1e-9 in + min burst_size (tokens +. (time_delta_seconds *. 
fill_rate)) + +let peek_with_timestamp timestamp tb = + let tb_state = Atomic.get tb.state in + compute_tokens timestamp tb_state ~burst_size:tb.burst_size + ~fill_rate:tb.fill_rate + +let peek tb = peek_with_timestamp (Mtime_clock.elapsed ()) tb + +let consume_with_timestamp get_time tb amount = + let rec try_consume () = + let timestamp = get_time () in + let old_state = Atomic.get tb.state in + let new_tokens = + compute_tokens timestamp old_state ~burst_size:tb.burst_size + ~fill_rate:tb.fill_rate + in + let success, final_tokens = + if new_tokens >= amount then + (true, new_tokens -. amount) + else + (false, new_tokens) + in + let new_state = {tokens= final_tokens; last_refill= timestamp} in + if Atomic.compare_and_set tb.state old_state new_state then + success + else + try_consume () + in + try_consume () + +let consume = consume_with_timestamp Mtime_clock.elapsed + +let get_delay_until_available_timestamp timestamp tb amount = + let {tokens; last_refill} = Atomic.get tb.state in + let current_tokens = + compute_tokens timestamp {tokens; last_refill} ~burst_size:tb.burst_size + ~fill_rate:tb.fill_rate + in + let required_tokens = max 0. (amount -. current_tokens) in + required_tokens /. tb.fill_rate + +let get_delay_until_available tb amount = + get_delay_until_available_timestamp (Mtime_clock.elapsed ()) tb amount + +(* This implementation only works when there is only one thread trying to + consume - fairness needs to be implemented on top of it with a queue. + If there is no contention, it should only delay once. 
*) +let rec delay_then_consume tb amount = + if not (consume tb amount) then ( + Thread.delay (get_delay_until_available tb amount) ; + delay_then_consume tb amount + ) diff --git a/ocaml/libs/rate-limit/token_bucket.mli b/ocaml/libs/rate-limit/token_bucket.mli new file mode 100644 index 00000000000..d04f4fd6174 --- /dev/null +++ b/ocaml/libs/rate-limit/token_bucket.mli @@ -0,0 +1,106 @@ +(* + * Copyright (C) 2025 Cloud Software Group + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +(** This module implements a classic token-bucket rate limiter. Token buckets + contain tokens that are refilled over time, and can be consumed in a + thread-safe way. A token bucket accumulates [fill_rate] tokens per second, + up to [burst_size]. Consumers may take tokens (if available), or query when + enough tokens will become available. + + Token buckets implement rate limiting by allowing operations to proceed + only when sufficient tokens are available - otherwise, the operations can + be delayed until enough tokens are available. + + To avoid doing unnecessary work to refill the bucket, token amounts are + only updated when a consume operation is carried out. The buckets keep a + last_refill timestamp which is updated on consume in tandem with the token + counts, and informs how many tokens should be added by the bucket refill. 
+
+    We include versions of functions that take a timestamp as a parameter for
+    testing purposes only - consumers of this library should use the
+    timestamp-less versions.
+*)
+
+type t
+
+val create : burst_size:float -> fill_rate:float -> t option
+(** Create token bucket with given parameters.
+    Returns None if the fill rate is 0 or negative.
+    @param burst_size Maximum number of tokens that can fit in the bucket
+    @param fill_rate Number of tokens added to the bucket per second
+ *)
+
+val peek : t -> float
+(** Retrieve current token amount
+    @param tb Token bucket
+    @return Amount of tokens in the token bucket
+ *)
+
+val consume : t -> float -> bool
+(** Consume tokens from the bucket in a thread-safe manner.
+    @param tb Token bucket
+    @param amount How many tokens to consume
+    @return Whether the tokens were successfully consumed
+ *)
+
+val get_delay_until_available : t -> float -> float
+(** Get number of seconds that need to pass until bucket is expected to have
+    enough tokens to fulfil the request
+    @param tb Token bucket
+    @param amount How many tokens we want to consume
+    @return Number of seconds until tokens are available
+*)
+
+val delay_then_consume : t -> float -> unit
+
+(**/**)
+
+(* Functions accepting a timestamp are meant for testing only *)
+
+val create_with_timestamp :
+  Mtime.span -> burst_size:float -> fill_rate:float -> t option
+(** Create token bucket with given parameters and supplied initial timestamp
+    Returns None if the fill_rate is 0 or negative.
+    @param timestamp Initial timestamp
+    @param burst_size Maximum number of tokens that can fit in the bucket
+    @param fill_rate Number of tokens added to the bucket per second
+ *)
+
+val peek_with_timestamp : Mtime.span -> t -> float
+(** Retrieve token amount in token bucket at given timestamp. 
+ Undefined behaviour when [timestamp] <= [tb.timestamp] + @param timestamp Current time + @param tb Token bucket + @return Amount of tokens in the token bucket + *) + +val consume_with_timestamp : (unit -> Mtime.span) -> t -> float -> bool +(** Consume tokens from the bucket in a thread-safe manner, using supplied + function for obtaining the current time + @param get_time Function to obtain timestamp, e.g. Mtime_clock.elapsed + @param tb Token bucket + @param amount How many tokens to consume + @return Whether the tokens were successfully consumed + *) + +val get_delay_until_available_timestamp : Mtime.span -> t -> float -> float +(** Get number of seconds that need to pass until bucket is expected to have + enough tokens to fulfil the request + @param timestamp + @param tb Token bucket + @param amount How many tokens we want to consume + @return Number of seconds until tokens are available +*) + +(**/**) diff --git a/ocaml/libs/uuid/uuidx.ml b/ocaml/libs/uuid/uuidx.ml index b22c22ebd14..8ae23a84052 100644 --- a/ocaml/libs/uuid/uuidx.ml +++ b/ocaml/libs/uuid/uuidx.ml @@ -64,6 +64,7 @@ type without_secret = | `sr_stat | `subject | `task + | `Rate_limit | `tunnel | `USB_group | `user diff --git a/ocaml/libs/uuid/uuidx.mli b/ocaml/libs/uuid/uuidx.mli index bd0865cf628..e3346480998 100644 --- a/ocaml/libs/uuid/uuidx.mli +++ b/ocaml/libs/uuid/uuidx.mli @@ -75,6 +75,7 @@ type without_secret = | `sr_stat | `subject | `task + | `Rate_limit | `tunnel | `USB_group | `user diff --git a/ocaml/xapi-cli-server/cli_frontend.ml b/ocaml/xapi-cli-server/cli_frontend.ml index d8185da9d47..9f046321865 100644 --- a/ocaml/xapi-cli-server/cli_frontend.ml +++ b/ocaml/xapi-cli-server/cli_frontend.ml @@ -3869,6 +3869,28 @@ let rec cmdtable_data : (string * cmd_spec) list = ; flags= [] } ) + ; ( "rate-limit-create" + , { + reqd= ["burst-size"; "fill-rate"] + ; optn= ["user-agent"; "host-ip"] + ; help= + "Add rate limiting to an XAPI client by identifying it via user \ + agent, IP address, 
or both, and by configuring a refill rate \ + (requests per second) and a burst size (maximum number of \ + concurrent requests)." + ; implementation= No_fd Cli_operations.Rate_limit.create + ; flags= [] + } + ) + ; ( "rate-limit-destroy" + , { + reqd= ["uuid"] + ; optn= [] + ; help= "Destroy rate limiter" + ; implementation= No_fd Cli_operations.Rate_limit.destroy + ; flags= [] + } + ) ] let cmdtable : (string, cmd_spec) Hashtbl.t = Hashtbl.create 50 diff --git a/ocaml/xapi-cli-server/cli_operations.ml b/ocaml/xapi-cli-server/cli_operations.ml index b20ed934107..d85ffbf777e 100644 --- a/ocaml/xapi-cli-server/cli_operations.ml +++ b/ocaml/xapi-cli-server/cli_operations.ml @@ -1409,6 +1409,11 @@ let gen_cmds rpc session_id = ["uuid"; "vendor-name"; "device-name"; "pci-id"] rpc session_id ) + ; Client.Rate_limit.( + mk get_all_records_where get_by_uuid rate_limit_record "rate-limit" [] + ["uuid"; "host-ip"; "user-agent"; "burst-size"; "fill-rate"] + rpc session_id + ) ] let message_create (_ : printer) rpc session_id params = @@ -8275,3 +8280,29 @@ module VM_group = struct in Client.VM_group.destroy ~rpc ~session_id ~self:ref end + +module Rate_limit = struct + let create printer rpc session_id params = + let user_agent = get_param params "user-agent" ~default:"" in + let host_ip = get_param params "host-ip" ~default:"" in + + if user_agent = "" && host_ip = "" then + failwith "Either user-agent or host-ip must be specified" ; + + let burst_size = float_of_string (List.assoc "burst-size" params) in + let fill_rate = float_of_string (List.assoc "fill-rate" params) in + + let ref = + Client.Rate_limit.create ~rpc ~session_id ~user_agent ~host_ip ~burst_size + ~fill_rate + in + let uuid = Client.Rate_limit.get_uuid ~rpc ~session_id ~self:ref in + printer (Cli_printer.PMsg uuid) + + let destroy _printer rpc session_id params = + let ref = + Client.Rate_limit.get_by_uuid ~rpc ~session_id + ~uuid:(List.assoc "uuid" params) + in + Client.Rate_limit.destroy ~rpc ~session_id 
~self:ref +end diff --git a/ocaml/xapi-cli-server/records.ml b/ocaml/xapi-cli-server/records.ml index d7f3cdf421d..8ace467e612 100644 --- a/ocaml/xapi-cli-server/records.ml +++ b/ocaml/xapi-cli-server/records.ml @@ -6030,3 +6030,41 @@ let pci_record rpc session_id pci = () ] } + +let rate_limit_record rpc session_id rate_limit = + let _ref = ref rate_limit in + let empty_record = + ToGet (fun () -> Client.Rate_limit.get_record ~rpc ~session_id ~self:!_ref) + in + let record = ref empty_record in + let x () = lzy_get record in + { + setref= + (fun r -> + _ref := r ; + record := empty_record + ) + ; setrefrec= + (fun (a, b) -> + _ref := a ; + record := Got b + ) + ; record= x + ; getref= (fun () -> !_ref) + ; fields= + [ + make_field ~name:"uuid" ~get:(fun () -> (x ()).API.rate_limit_uuid) () + ; make_field ~name:"user-agent" + ~get:(fun () -> (x ()).API.rate_limit_user_agent) + () + ; make_field ~name:"host-ip" + ~get:(fun () -> (x ()).API.rate_limit_host_ip) + () + ; make_field ~name:"burst-size" + ~get:(fun () -> string_of_float (x ()).API.rate_limit_burst_size) + () + ; make_field ~name:"fill-rate" + ~get:(fun () -> string_of_float (x ()).API.rate_limit_fill_rate) + () + ] + } diff --git a/ocaml/xapi/api_server_common.ml b/ocaml/xapi/api_server_common.ml index ef3e1f4bbc3..246370decb7 100644 --- a/ocaml/xapi/api_server_common.ml +++ b/ocaml/xapi/api_server_common.ml @@ -132,6 +132,7 @@ module Actions = struct module Observer = Xapi_observer module Host_driver = Xapi_host_driver module Driver_variant = Xapi_host_driver.Variant + module Rate_limit = Xapi_rate_limit end (** Use the server functor to make an XML-RPC dispatcher. 
*) diff --git a/ocaml/xapi/context.ml b/ocaml/xapi/context.ml index 419c7d3f04d..0ca0a2bc4af 100644 --- a/ocaml/xapi/context.ml +++ b/ocaml/xapi/context.ml @@ -85,6 +85,8 @@ let task_in_database ctx = Ref.is_real ctx.task_id let get_origin ctx = string_of_origin ctx.origin +let is_internal_origin ctx = ctx.origin = Internal + let database_of x = x.database (** Calls coming in from the main unix socket are pre-authenticated. diff --git a/ocaml/xapi/context.mli b/ocaml/xapi/context.mli index ac3250f8569..4d2c32c5706 100644 --- a/ocaml/xapi/context.mli +++ b/ocaml/xapi/context.mli @@ -87,6 +87,9 @@ val task_in_database : t -> bool val get_origin : t -> string (** [get_origin __context] returns a string containing the origin of [__context]. *) +val is_internal_origin : t -> bool +(** [is_internal_origin __context] returns true if the context originated from an internal operation. *) + val database_of : t -> Xapi_database.Db_ref.t (** [database_of __context] returns a database handle, which can be used by Db.* *) diff --git a/ocaml/xapi/dune b/ocaml/xapi/dune index 139bf4d0f09..52857a55961 100644 --- a/ocaml/xapi/dune +++ b/ocaml/xapi/dune @@ -65,6 +65,7 @@ exnHelper rbac_static xapi_role + xapi_rate_limit xapi_extensions db) (modes best) @@ -83,6 +84,7 @@ threads.posix fmt clock + rate-limit astring stunnel sexplib0 @@ -129,6 +131,7 @@ locking_helpers exnHelper xapi_role + xapi_rate_limit xapi_extensions db)) (libraries @@ -166,6 +169,7 @@ psq ptime ptime.clock.os + rate-limit rpclib.core rpclib.json rpclib.xml diff --git a/ocaml/xapi/message_forwarding.ml b/ocaml/xapi/message_forwarding.ml index 1d8d228cc80..6f75895a39f 100644 --- a/ocaml/xapi/message_forwarding.ml +++ b/ocaml/xapi/message_forwarding.ml @@ -6806,6 +6806,8 @@ functor in Xapi_pool_helpers.call_fn_on_slaves_then_master ~__context fn end + + module Rate_limit = Xapi_rate_limit end (* for unit tests *) diff --git a/ocaml/xapi/server_helpers.ml b/ocaml/xapi/server_helpers.ml index 0fe9383c737..638b33ca9cb 
100644 --- a/ocaml/xapi/server_helpers.ml +++ b/ocaml/xapi/server_helpers.ml @@ -179,20 +179,46 @@ let do_dispatch ?session_id ?forward_op ?self:_ supports_async called_fn_name ~marshaller op_fn ) () - ) ; - (* Return task id immediately *) - Rpc.success (API.rpc_of_ref_task (Context.get_task_id __context)) + ) in - match sync_ty with - | `Sync -> - sync () - | `Async -> - let need_complete = not (Context.forwarded_task __context) in - async ~need_complete - | `InternalAsync -> - async ~need_complete:true - -(* regardless of forwarding, we are expected to complete the task *) + let handle_request () = + match sync_ty with + | `Sync -> + sync () + | `Async -> + let need_complete = not (Context.forwarded_task __context) in + async ~need_complete ; + Rpc.success (API.rpc_of_ref_task (Context.get_task_id __context)) + | `InternalAsync -> + async ~need_complete:true ; + Rpc.success (API.rpc_of_ref_task (Context.get_task_id __context)) + in + let handle_request_throttled () = + let token_cost = Xapi_rate_limit.get_token_cost called_fn_name in + let client_id = + Xapi_rate_limit.Key. 
+ { + user_agent= Option.value http_req.user_agent ~default:"" + ; host_ip= Option.value (Context.get_client_ip __context) ~default:"" + } + in + match sync_ty with + | `Sync -> + Xapi_rate_limit.submit_sync ~client_id ~callback:sync token_cost + | `Async -> + let need_complete = not (Context.forwarded_task __context) in + Xapi_rate_limit.submit ~client_id + ~callback:(fun () -> async ~need_complete) + token_cost ; + Rpc.success (API.rpc_of_ref_task (Context.get_task_id __context)) + | `InternalAsync -> + async ~need_complete:true ; + Rpc.success (API.rpc_of_ref_task (Context.get_task_id __context)) + in + if Context.is_internal_origin __context then + handle_request () + else + handle_request_throttled () (* in the following functions, it is our responsibility to complete any tasks we create *) let exec_with_new_task ?http_other_config ?quiet ?subtask_of ?session_id diff --git a/ocaml/xapi/xapi.ml b/ocaml/xapi/xapi.ml index 785950c384e..d63844ceb59 100644 --- a/ocaml/xapi/xapi.ml +++ b/ocaml/xapi/xapi.ml @@ -884,11 +884,9 @@ let listen_unix_socket sock_path = Unixext.mkdir_safe (Filename.dirname sock_path) 0o700 ; Unixext.unlink_safe sock_path ; let domain_sock = Xapi_http.bind (Unix.ADDR_UNIX sock_path) in - ignore - (Http_svr.start - ~conn_limit:!Xapi_globs.conn_limit_unix - Xapi_http.server domain_sock - ) + Http_svr.start + ~conn_limit:!Xapi_globs.conn_limit_unix + Xapi_http.server domain_sock let set_stunnel_timeout () = try @@ -1169,6 +1167,10 @@ let server_init () = , [] , fun () -> report_tls_verification ~__context ) + ; ( "Registering rate limits" + , [Startup.OnlyMaster] + , fun () -> Xapi_rate_limit.register ~__context + ) ; ( "Remote requests" , [Startup.OnThread] , Remote_requests.handle_requests diff --git a/ocaml/xapi/xapi_http.ml b/ocaml/xapi/xapi_http.ml index 964983d8eda..406197df440 100644 --- a/ocaml/xapi/xapi_http.ml +++ b/ocaml/xapi/xapi_http.ml @@ -351,25 +351,43 @@ let add_handler (name, handler) = failwith (Printf.sprintf "Unregistered 
HTTP handler: %s" name) in let check_rbac = Rbac.is_rbac_enabled_for_http_action name in - let h req ic context = - let client = - Http_svr.(client_of_req_and_fd req ic |> Option.map string_of_client) + let rate_limit (client_id_opt : Xapi_rate_limit.Key.t option) handler () = + if List.mem name Datamodel.custom_rate_limit_http_actions then + handler () + else + match client_id_opt with + | None -> + handler () + | Some ({user_agent; host_ip} as client_id) -> + debug "Rate limiting handler %s with user_agent %s host_ip %s" name + user_agent host_ip ; + Xapi_rate_limit.submit ~client_id ~callback:handler + Xapi_rate_limit.median_token_cost + in + let h req ic () = + let client_info = Http_svr.client_of_req_and_fd req ic in + let client = Option.map Http_svr.string_of_client client_info in + let client_id = + match (req.Http.Request.user_agent, client_info) with + | Some user_agent, Some (_, ip) -> + Some Xapi_rate_limit.Key.{user_agent; host_ip= Ipaddr.to_string ip} + | _ -> + None in + let rate_limited_handler = rate_limit client_id (handler req ic) in Debug.with_thread_associated ?client name (fun () -> try if check_rbac then ( try (* session and rbac checks *) - assert_credentials_ok name req - ~fn:(fun () -> handler req ic context) - ic + assert_credentials_ok name req ~fn:rate_limited_handler ic with e -> debug "Leaving RBAC-handler in xapi_http after: %s" (ExnHelper.string_of_exn e) ; raise e ) else (* no rbac checks *) - handler req ic context + rate_limited_handler () with Api_errors.Server_error (name, params) as e -> error "Unhandled Api_errors.Server_error(%s, [ %s ])" name (String.concat "; " params) ; diff --git a/ocaml/xapi/xapi_rate_limit.ml b/ocaml/xapi/xapi_rate_limit.ml new file mode 100644 index 00000000000..b745b222f2f --- /dev/null +++ b/ocaml/xapi/xapi_rate_limit.ml @@ -0,0 +1,744 @@ +(* + * Copyright (C) Citrix Systems Inc. 
+ * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) +module D = Debug.Make (struct let name = "xapi_rate_limit" end) + +module Bucket_table = Rate_limit.Bucket_table +module Key = Rate_limit.Bucket_table.Key + +let bucket_table = Bucket_table.create () + +let submit_sync ~client_id ~callback amount = + Bucket_table.submit_sync bucket_table ~client_id ~callback amount + +let submit ~client_id ~callback amount = + Bucket_table.submit bucket_table ~client_id ~callback amount + +let peek ~client_id = Bucket_table.peek bucket_table ~client_id + +let create ~__context ~user_agent ~host_ip ~burst_size ~fill_rate = + if user_agent = "" && host_ip = "" then + raise + Api_errors.( + Server_error + (invalid_value, ["Expected user_agent or host_ip to be nonempty"]) + ) ; + let client_id = Key.{user_agent; host_ip} in + if Bucket_table.mem bucket_table ~client_id then + raise + Api_errors.( + Server_error + ( map_duplicate_key + , ["user_agent"; user_agent; "user_agent already registered"] + ) + ) ; + let uuid = Uuidx.make () in + let ref = Ref.make () in + let add_bucket_succeeded = + Bucket_table.add_bucket bucket_table ~client_id ~burst_size ~fill_rate + in + match add_bucket_succeeded with + | true -> + Db.Rate_limit.create ~__context ~ref ~uuid:(Uuidx.to_string uuid) + ~user_agent ~host_ip ~burst_size ~fill_rate ; + ref + | false -> + raise + Api_errors.( + Server_error + ( invalid_value + , [ + "fill_rate" + ; string_of_float fill_rate + ; "Fill rate cannot be 0 or negative" + ] + ) + ) + 
+let destroy ~__context ~self = + let record = Db.Rate_limit.get_record ~__context ~self in + let client_id = + Key. + { + user_agent= record.rate_limit_user_agent + ; host_ip= record.rate_limit_host_ip + } + in + Bucket_table.delete_bucket bucket_table ~client_id ; + Db.Rate_limit.destroy ~__context ~self + +let register ~__context = + List.iter + (fun (_, bucket) -> + let client_id = + Key. + { + user_agent= bucket.API.rate_limit_user_agent + ; host_ip= bucket.API.rate_limit_host_ip + } + in + ignore + (Bucket_table.add_bucket bucket_table ~client_id + ~fill_rate:bucket.API.rate_limit_fill_rate + ~burst_size:bucket.API.rate_limit_burst_size + ) + ) + (Db.Rate_limit.get_all_records ~__context) + +(* These are the average times taken by xapi to fulfil the requests *) +let token_costs = + Hashtbl.of_seq + (List.to_seq + [ + ("VDI.pool_migrate", 257.909552) + ; ("VM.migrate_send", 186.357493) + ; ("host.apply_edition", 49.457409) + ; ("VM.suspend", 44.825024) + ; ("VM.resume_on", 43.383270) + ; ("SR.probe", 41.225226) + ; ("VM.copy", 32.585879) + ; ("pool.enable_ha", 28.515320) + ; ("VM.checkpoint", 23.014775) + ; ("host.ha_join_liveset", 20.912049) + ; ("Cluster.pool_create", 19.861666) + ; ("VDI.copy", 19.146503) + ; ("VM.pool_migrate", 16.976464) + ; ("VM.resume", 16.593534) + ; ("SR.destroy", 16.158883) + ; ("Cluster_host.create", 15.364815) + ; ("event.from", 13.125585) + ; ("pool.management_reconfigure", 11.106379) + ; ("pool.join", 9.706707) + ; ("pool.disable_ha", 9.662020) + ; ("host.prepare_for_poweroff", 9.493340) + ; ("VM.set_memory_dynamic_range", 8.171569) + ; ("host.evacuate", 7.635096) + ; ("VM.clean_reboot", 6.922801) + ; ("VM.restart_device_models", 6.726358) + ; ("pool_update.apply", 6.677579) + ; ("Bond.create", 6.063720) + ; ("VM.clean_shutdown", 5.779665) + ; ("VM.revert", 5.167195) + ; ("host.install_server_certificate", 5.165420) + ; ("pool.eject", 4.460441) + ; ("Cluster.create", 4.424762) + ; ("pool.sync_updates", 4.387346) + ; 
("host.apply_updates", 3.887800) + ; ("SR.probe_ext", 3.778658) + ; ("host.ha_wait_for_shutdown_via_statefile", 3.677161) + ; ("pool_update.precheck", 3.306154) + ; ("event.next", 3.253017) + ; ("VDI.snapshot", 2.983962) + ; ("pool_update.introduce", 2.753982) + ; ("pool.enable_external_auth", 2.226998) + ; ("VM.start_on", 2.190845) + ; ("VM.hard_reboot", 2.139073) + ; ("SR.create", 2.103516) + ; ("VM.hard_shutdown", 2.066366) + ; ("pool.designate_new_master", 1.967346) + ; ("VM.start", 1.830110) + ; ("VDI.clone", 1.712122) + ; ("host.ha_release_resources", 1.646541) + ; ("VM.snapshot", 1.459288) + ; ("pool.is_slave", 1.445609) + ; ("pool.recover_slaves", 1.427092) + ; ("host.preconfigure_ha", 1.410345) + ; ("pool_update.detach", 1.350796) + ; ("pool_update.attach", 1.338446) + ; ("host.update_master", 1.336511) + ; ("PBD.plug", 1.267668) + ; ("Repository.apply", 1.237862) + ; ("pool.emergency_reset_master", 1.232223) + ; ("VBD.plug", 1.144580) + ; ("host.commit_new_master", 1.144527) + ; ("SR.scan", 1.133962) + ; ("host.enable_external_auth", 1.107988) + ; ("VBD.unplug", 1.070053) + ; ("pool_update.pool_clean", 1.044175) + ; ("VM.clone", 1.019502) + ; ("VM.provision", 0.981767) + ; ("PIF.reconfigure_ip", 0.909502) + ; ("pool.create_VLAN_from_PIF", 0.875542) + ; ("pool.apply_edition", 0.829413) + ; ("pool.disable_external_auth", 0.745834) + ; ("VM.pool_migrate_complete", 0.742376) + ; ("host.call_plugin", 0.741278) + ; ("VLAN.create", 0.616747) + ; ("VDI.create", 0.606362) + ; ("host.update_firewalld_service_status", 0.579426) + ; ("VDI.destroy", 0.550137) + ; ("VIF.plug", 0.536839) + ; ("host.set_iscsi_iqn", 0.487450) + ; ("SR.update", 0.482580) + ; ("VDI.resize", 0.437588) + ; ("host.management_reconfigure", 0.369713) + ; ("VIF.unplug", 0.340281) + ; ("host.set_https_only", 0.337233) + ; ("PIF.plug", 0.333846) + ; ("host.disable_external_auth", 0.323262) + ; ("VDI.set_name_label", 0.316701) + ; ("VDI.set_name_description", 0.272491) + ; ("PIF.scan", 0.264133) + ; 
("PBD.unplug", 0.254334) + ; ("pool_update.resync_host", 0.215983) + ; ("VDI.generate_config", 0.199678) + ; ("PIF.reconfigure_ipv6", 0.192960) + ; ("host.get_system_status_capabilities", 0.185781) + ; ("host.request_backup", 0.179546) + ; ("host.signal_networking_change", 0.178460) + ; ("PIF.destroy", 0.175272) + ; ("pool.enable_tls_verification", 0.160637) + ; ("pool.set_repositories", 0.146938) + ; ("pool.exchange_certificates_on_join", 0.138671) + ; ("subject.create", 0.118591) + ; ("host.shutdown_agent", 0.118081) + ; ("VLAN.destroy", 0.116394) + ; ("VM.unpause", 0.110404) + ; ("session.change_password", 0.109895) + ; ("host.syslog_reconfigure", 0.099574) + ; ("pool.exchange_ca_certificates_on_join", 0.093144) + ; ("pool.sync_database", 0.086285) + ; ("VM.assert_can_migrate_sender", 0.077286) + ; ("VM.pause", 0.074863) + ; ("pool.hello", 0.072118) + ; ("host.request_config_file_sync", 0.062448) + ; ("network.attach_for_vm", 0.051612) + ; ("host.migrate_receive", 0.051597) + ; ("auth.get_subject_identifier", 0.044924) + ; ("host.cert_distrib_atom", 0.042307) + ; ("PBD.create", 0.035094) + ; ("Repository.apply_livepatch", 0.028736) + ; ("session.login_with_password", 0.027734) + ; ("host.update_pool_secret", 0.027330) + ; ("host.enable", 0.027242) + ; ("host.get_data_sources", 0.026641) + ; ("Observer.register", 0.025281) + ; ("SR.forget", 0.021780) + ; ("host.disable", 0.021741) + ; ("Observer.set_attributes", 0.020536) + ; ("VDI.get_all_records", 0.020524) + ; ("host.set_multipathing", 0.019336) + ; ("Observer.set_endpoints", 0.018160) + ; ("Observer.create", 0.015905) + ; ("Observer.set_enabled", 0.015219) + ; ("host.reboot", 0.015137) + ; ("VM.set_VCPUs_number_live", 0.014856) + ; ("host.get_sm_diagnostics", 0.013681) + ; ("host.set_license_params", 0.012687) + ; ("role.get_all_records", 0.012687) + ; ("host.apply_guest_agent_config", 0.010359) + ; ("VM.atomic_set_resident_on", 0.009822) + ; ("subject.destroy", 0.005524) + ; ("message.get", 0.005002) + ; 
("PBD.destroy", 0.004741) + ; ("host.get_diagnostic_timing_stats", 0.004670) + ; ("VDI.get_all_records_where", 0.004475) + ; ("host.write_uefi_certificates_to_disk", 0.004337) + ; ("host.backup_rrds", 0.004214) + ; ("VIF.create", 0.004065) + ; ("session.slave_login", 0.003956) + ; ("host.get_thread_diagnostics", 0.003906) + ; ("host.enable_local_storage_caching", 0.003829) + ; ("SR.get_data_sources", 0.003741) + ; ("pool.add_repository", 0.003725) + ; ("task.cancel", 0.003540) + ; ("VM.set_NVRAM_EFI_variables", 0.003537) + ; ("VBD.create", 0.003263) + ; ("session.slave_local_login_with_password", 0.003192) + ; ("VM.destroy", 0.003171) + ; ("VDI.forget", 0.003090) + ; ("SR.set_name_description", 0.002877) + ; ("VBD.destroy", 0.002780) + ; ("host.create", 0.002697) + ; ("VM.query_data_source", 0.002616) + ; ("PUSB.scan", 0.002592) + ; ("VTPM.get_contents", 0.002470) + ; ("pool.get_license_state", 0.002393) + ; ("host.get_vms_which_prevent_evacuation", 0.002368) + ; ("host.add_to_guest_VCPUs_params", 0.002341) + ; ("VM.assert_can_boot_here", 0.002287) + ; ("VM.set_memory_limits", 0.002266) + ; ("VM.update_allowed_operations", 0.002096) + ; ("VM.get_possible_hosts", 0.002050) + ; ("VM.get_all_records_where", 0.002033) + ; ("Diagnostics.db_stats", 0.002011) + ; ("VM.set_actions_after_crash", 0.001944) + ; ("VIF.destroy", 0.001822) + ; ("VDI.db_forget", 0.001783) + ; ("host.set_name_description", 0.001782) + ; ("message.destroy_many", 0.001770) + ; ("VM.create", 0.001747) + ; ("host.tickle_heartbeat", 0.001734) + ; ("VBD.eject", 0.001714) + ; ("VGPU.atomic_set_resident_on", 0.001573) + ; ("Diagnostics.network_stats", 0.001546) + ; ("VM.get_all_records", 0.001503) + ; ("event.inject", 0.001468) + ; ("VM.set_VCPUs_max", 0.001436) + ; ("pool.reset_telemetry_uuid", 0.001400) + ; ("PBD.set_device_config", 0.001381) + ; ("VTPM.set_contents", 0.001340) + ; ("session.get_uuid", 0.001310) + ; ("PIF.pool_introduce", 0.001294) + ; ("host.allocate_resources_for_vm", 0.001254) + ; 
("host_cpu.get_all_records_where", 0.001248) + ; ("VM.set_has_vendor_device", 0.001242) + ; ("VBD.assert_attachable", 0.001226) + ; ("VDI.set_cbt_enabled", 0.001149) + ; ("message.create", 0.001136) + ; ("VM.add_to_blocked_operations", 0.001111) + ; ("host.set_crash_dump_sr", 0.001095) + ; ("PCI.get_all_records", 0.001091) + ; ("network.create", 0.001089) + ; ("VGPU.create", 0.001088) + ; ("VM.create_new_blob", 0.001077) + ; ("VDI.db_introduce", 0.001069) + ; ("VM.set_name_label", 0.001032) + ; ("pool.set_ext_auth_cache_expiry", 0.001027) + ; ("message.get_all_records_where", 0.001022) + ; ("host.remove_from_license_server", 0.001021) + ; ("VM.remove_from_blocked_operations", 0.000958) + ; ("VTPM.create", 0.000923) + ; ("host.sync_pif_currently_attached", 0.000922) + ; ("host.add_to_logging", 0.000906) + ; ("VM.set_is_default_template", 0.000900) + ; ("host.add_to_license_server", 0.000897) + ; ("VDI.read_database_pool_uuid", 0.000894) + ; ("VM_metrics.get_all_records", 0.000883) + ; ("host.remove_from_logging", 0.000880) + ; ("VLAN.pool_introduce", 0.000864) + ; ("Repository.introduce_bundle", 0.000863) + ; ("Repository.introduce", 0.000863) + ; ("host.set_license_server", 0.000827) + ; ("host.set_suspend_image_sr", 0.000810) + ; ("message.get_all_records", 0.000799) + ; ("VM.get_HVM_boot_params", 0.000787) + ; ("PGPU.set_GPU_group", 0.000767) + ; ("host.compute_free_memory", 0.000764) + ; ("SR.introduce", 0.000747) + ; ("pool_update.destroy", 0.000725) + ; ("host.add_to_other_config", 0.000709) + ; ("VM.get_cooperative", 0.000696) + ; ("VTPM.destroy", 0.000680) + ; ("GPU_group.set_allocation_algorithm", 0.000669) + ; ("host_metrics.get_all_records", 0.000669) + ; ("VM.set_actions_after_shutdown", 0.000638) + ; ("Repository.forget", 0.000621) + ; ("pool.set_ext_auth_cache_size", 0.000615) + ; ("blob.create", 0.000614) + ; ("task.set_status", 0.000605) + ; ("VM.set_other_config", 0.000596) + ; ("pool.set_ext_auth_cache_enabled", 0.000592) + ; 
("host.get_tracked_user_agents", 0.000586) + ; ("PIF.set_disallow_unplug", 0.000586) + ; ("host.query_data_source", 0.000580) + ; ("pool_update.get_all_records_where", 0.000560) + ; ("pool.disable_repository_proxy", 0.000551) + ; ("network.destroy", 0.000524) + ; ("host.set_other_config", 0.000521) + ; ("host.get_all_records", 0.000501) + ; ("VDI.pool_introduce", 0.000481) + ; ("pool.set_telemetry_next_collection", 0.000458) + ; ("host.get_management_interface", 0.000445) + ; ("VM.set_name_description", 0.000438) + ; ("GPU_group.update_enabled_VGPU_types", 0.000435) + ; ("pool.initial_auth", 0.000412) + ; ("pool.set_default_SR", 0.000406) + ; ("GPU_group.update_supported_VGPU_types", 0.000400) + ; ("VM.add_to_other_config", 0.000397) + ; ("VM.get_snapshots", 0.000385) + ; ("VM.remove_from_platform", 0.000383) + ; ("SM.get_all_records", 0.000380) + ; ("host.get_all_records_where", 0.000378) + ; ("network.pool_introduce", 0.000375) + ; ("VM.remove_from_other_config", 0.000365) + ; ("VM.set_VCPUs_at_startup", 0.000362) + ; ("SR.get_all_records", 0.000360) + ; ("secret.create", 0.000351) + ; ("host_cpu.get_all_records", 0.000340) + ; ("secret.destroy", 0.000337) + ; ("VGPU_type.get_all_records_where", 0.000329) + ; ("PIF.get_network", 0.000325) + ; ("event.register", 0.000324) + ; ("PGPU.get_by_uuid", 0.000322) + ; ("PIF.get_all_records_where", 0.000321) + ; ("VM.remove_from_HVM_boot_params", 0.000320) + ; ("PBD.add_to_other_config", 0.000313) + ; ("PIF.add_to_other_config", 0.000312) + ; ("task.create", 0.000311) + ; ("PBD.remove_from_other_config", 0.000310) + ; ("pool.set_crash_dump_SR", 0.000308) + ; ("pool.add_to_license_server", 0.000305) + ; ("host.sync_vlans", 0.000303) + ; ("subject.add_to_roles", 0.000303) + ; ("VM.add_to_platform", 0.000296) + ; ("VDI.get_by_name_label", 0.000291) + ; ("pool.get_all_records_where", 0.000280) + ; ("VM.add_to_HVM_boot_params", 0.000278) + ; ("VM.set_suspend_SR", 0.000277) + ; ("pool.set_suspend_image_SR", 0.000274) + ; 
("VDI.set_snapshot_of", 0.000271) + ; ("host.get_record", 0.000270) + ; ("network.add_to_other_config", 0.000269) + ; ("PIF.get_all_records", 0.000264) + ; ("network.remove_from_other_config", 0.000261) + ; ("VM.set_HVM_boot_policy", 0.000260) + ; ("VDI.add_to_sm_config", 0.000260) + ; ("VDI.set_managed", 0.000256) + ; ("task.destroy", 0.000250) + ; ("pool.add_to_other_config", 0.000245) + ; ("VBD.get_all_records", 0.000244) + ; ("VM.set_is_a_template", 0.000236) + ; ("VM.set_PV_args", 0.000232) + ; ("VDI.remove_from_sm_config", 0.000230) + ; ("SR.get_all_records_where", 0.000227) + ; ("secret.introduce", 0.000225) + ; ("VDI.set_sm_config", 0.000222) + ; ("task.set_other_config", 0.000219) + ; ("SR.set_other_config", 0.000217) + ; ("host.set_ssl_legacy", 0.000214) + ; ("pool.remove_from_other_config", 0.000211) + ; ("SR.add_to_other_config", 0.000208) + ; ("VM.get_record", 0.000206) + ; ("pool.set_other_config", 0.000205) + ; ("VDI.add_to_other_config", 0.000203) + ; ("pool.get_all_records", 0.000200) + ; ("pool.detect_nonhomogeneous_external_auth", 0.000197) + ; ("SR.set_sm_config", 0.000196) + ; ("network.set_MTU", 0.000194) + ; ("SR.set_virtual_allocation", 0.000191) + ; ("network.get_all_records", 0.000187) + ; ("VM.get_allowed_VBD_devices", 0.000186) + ; ("message.get_by_uuid", 0.000186) + ; ("session.logout", 0.000182) + ; ("VDI.set_physical_utilisation", 0.000180) + ; ("pool.remove_from_license_server", 0.000178) + ; ("host.sync_tunnels", 0.000171) + ; ("SR.add_to_sm_config", 0.000168) + ; ("VBD.add_to_other_config", 0.000168) + ; ("host.emergency_clear_mandatory_guidance", 0.000167) + ; ("PIF_metrics.get_all_records", 0.000164) + ; ("PBD.get_by_uuid", 0.000163) + ; ("task.get_all_records_where", 0.000163) + ; ("pool.get_record", 0.000162) + ; ("VIF.get_all_records", 0.000161) + ; ("SM.add_to_other_config", 0.000161) + ; ("SR.remove_from_sm_config", 0.000160) + ; ("task.set_progress", 0.000160) + ; ("VM.get_NVRAM", 0.000155) + ; ("message.get_record", 
0.000151) + ; ("host.remove_from_other_config", 0.000151) + ; ("event.unregister", 0.000141) + ; ("VGPU.get_PCI", 0.000135) + ; ("VM.get_by_name_label", 0.000135) + ; ("PBD.get_all_records", 0.000134) + ; ("SR.set_physical_utilisation", 0.000132) + ; ("host.ha_xapi_healthcheck", 0.000132) + ; ("VDI.set_virtual_size", 0.000131) + ; ("Certificate.get_all_records", 0.000129) + ; ("PGPU.get_all_records_where", 0.000128) + ; ("VIF.get_all_records_where", 0.000125) + ; ("session.slave_local_login", 0.000125) + ; ("PBD.get_all", 0.000124) + ; ("VBD.get_all_records_where", 0.000122) + ; ("SR.set_physical_size", 0.000118) + ; ("host.emergency_ha_disable", 0.000117) + ; ("VDI.set_read_only", 0.000117) + ; ("network.get_all_records_where", 0.000114) + ; ("task.get_record", 0.000113) + ; ("pool.get_suspend_image_SR", 0.000112) + ; ("VLAN.get_by_uuid", 0.000111) + ; ("SM.get_all_records_where", 0.000110) + ; ("network.get_name_label", 0.000110) + ; ("SM.remove_from_other_config", 0.000110) + ; ("VDI.remove_from_other_config", 0.000106) + ; ("VM.get_VBDs", 0.000105) + ; ("blob.get_record", 0.000104) + ; ("task.get_name_label", 0.000104) + ; ("VM_guest_metrics.get_all_records", 0.000102) + ; ("task.get_all_records", 0.000101) + ; ("VDI.get_uuid", 0.000101) + ; ("SR.get_record", 0.000099) + ; ("PBD.get_all_records_where", 0.000099) + ; ("pool.get_policy_no_vendor_device", 0.000098) + ; ("secret.set_value", 0.000097) + ; ("VBD.get_record", 0.000095) + ; ("blob.get_all_records", 0.000094) + ; ("host.get_external_auth_type", 0.000094) + ; ("SDN_controller.get_all_records", 0.000093) + ; ("VIF.get_record", 0.000092) + ; ("secret.get_all_records", 0.000092) + ; ("host.propose_new_master", 0.000087) + ; ("console.get_record", 0.000087) + ; ("PIF.get_record", 0.000086) + ; ("pool.get_other_config", 0.000086) + ; ("VM_guest_metrics.get_record", 0.000086) + ; ("PIF.get_all_where", 0.000085) + ; ("VM_metrics.get_record", 0.000084) + ; ("pool_patch.get_all_records", 0.000083) + ; 
("VLAN.get_all_records_where", 0.000083) + ; ("session.get_is_local_superuser", 0.000081) + ; ("task.get_progress", 0.000081) + ; ("Repository.get_by_name_label", 0.000081) + ; ("VBD.get_VDI", 0.000081) + ; ("VBD.get_mode", 0.000081) + ; ("SR.get_VDIs", 0.000079) + ; ("host.is_in_emergency_mode", 0.000079) + ; ("host_crashdump.get_all_records_where", 0.000078) + ; ("host_metrics.get_record", 0.000077) + ; ("Repository.get_all_records_where", 0.000077) + ; ("host.get_ssh_enabled", 0.000076) + ; ("Bond.get_all_records_where", 0.000076) + ; ("VM.get_all", 0.000076) + ; ("pool_update.get_record", 0.000075) + ; ("host.get_capabilities", 0.000075) + ; ("task.get_created", 0.000074) + ; ("VM.get_VIFs", 0.000073) + ; ("VM.get_allowed_VIF_devices", 0.000073) + ; ("VM.get_other_config", 0.000073) + ; ("pool.get_tls_verification_enabled", 0.000072) + ; ("host.set_name_label", 0.000071) + ; ("PBD.get_all_where", 0.000071) + ; ("Cluster.get_all_records_where", 0.000070) + ; ("console.get_all_records", 0.000070) + ; ("VM.get_allowed_operations", 0.000070) + ; ("VM.get_recommendations", 0.000070) + ; ("session.get_rbac_permissions", 0.000069) + ; ("Repository.get_by_uuid", 0.000069) + ; ("task.get_uuid", 0.000069) + ; ("network_sriov.get_all_records_where", 0.000068) + ; ("host.get_all_where", 0.000068) + ; ("PUSB.get_all_records", 0.000068) + ; ("VBD.get_userdevice", 0.000068) + ; ("VDI.get_managed", 0.000068) + ; ("pool.get_default_SR", 0.000067) + ; ("host.get_supported_bootloaders", 0.000067) + ; ("VBD.get_other_config", 0.000067) + ; ("VBD.get_by_uuid", 0.000066) + ; ("VGPU.get_all", 0.000066) + ; ("GPU_group.get_all_records", 0.000066) + ; ("PGPU.get_all_records", 0.000066) + ; ("PIF.get_by_uuid", 0.000065) + ; ("network.get_record", 0.000064) + ; ("Repository.get_record", 0.000064) + ; ("host.get_PBDs", 0.000063) + ; ("VDI.get_is_a_snapshot", 0.000062) + ; ("SR.get_all", 0.000062) + ; ("PBD.get_record", 0.000062) + ; ("VM.get_name_label", 0.000062) + ; 
("GPU_group.get_all_records_where", 0.000062) + ; ("SR.get_by_name_label", 0.000061) + ; ("VMSS.get_all_records", 0.000060) + ; ("network.get_by_name_label", 0.000060) + ; ("pool.get_all", 0.000059) + ; ("Bond.get_record", 0.000059) + ; ("network.get_bridge", 0.000059) + ; ("SR.get_PBDs", 0.000059) + ; ("pool_patch.get_all_records_where", 0.000057) + ; ("VIF.get_by_uuid", 0.000057) + ; ("pool.get_by_uuid", 0.000056) + ; ("Cluster.get_uuid", 0.000056) + ; ("VM.get_is_a_snapshot", 0.000056) + ; ("host.get_other_config", 0.000056) + ; ("VM.get_uuid", 0.000055) + ; ("VM.get_name_description", 0.000055) + ; ("Cluster.get_all", 0.000055) + ; ("PBD.get_SR", 0.000054) + ; ("DR_task.get_all_records", 0.000054) + ; ("host.get_license_params", 0.000054) + ; ("VMSS.get_all", 0.000053) + ; ("PBD.get_host", 0.000053) + ; ("network.get_PIFs", 0.000053) + ; ("Bond.get_by_uuid", 0.000053) + ; ("VM.get_by_uuid", 0.000053) + ; ("VDI.get_record", 0.000052) + ; ("host.get_by_uuid", 0.000052) + ; ("VUSB.get_all_records", 0.000052) + ; ("VM.get_suspend_VDI", 0.000052) + ; ("PIF_metrics.get_record", 0.000052) + ; ("host.get_API_version_minor", 0.000051) + ; ("VDI.get_metadata_of_pool", 0.000051) + ; ("PCI.get_pci_id", 0.000051) + ; ("task.get_result", 0.000051) + ; ("VDI.get_by_uuid", 0.000051) + ; ("VDI.get_sm_config", 0.000051) + ; ("network.get_by_uuid", 0.000050) + ; ("VBD.get_device", 0.000050) + ; ("host.get_name_label", 0.000050) + ; ("PVS_cache_storage.get_all_records", 0.000050) + ; ("VM.get_is_a_template", 0.000050) + ; ("VLAN.get_record", 0.000049) + ; ("VDI.get_on_boot", 0.000049) + ; ("SR.get_virtual_allocation", 0.000049) + ; ("SR.get_by_uuid", 0.000049) + ; ("VBD.get_type", 0.000048) + ; ("SM.get_by_name_label", 0.000048) + ; ("VIF.get_device", 0.000048) + ; ("Certificate.get_all_records_where", 0.000048) + ; ("Observer.get_by_uuid", 0.000048) + ; ("Feature.get_all_records", 0.000048) + ; ("pool_update.get_all_records", 0.000048) + ; ("host.get_PGPUs", 0.000048) + ; 
("SR.get_sm_config", 0.000048) + ; ("host_cpu.get_record", 0.000048) + ; ("VBD.get_uuid", 0.000047) + ; ("host.get_virtual_hardware_platform_versions", 0.000047) + ; ("pool.get_ha_enabled", 0.000047) + ; ("host.get_editions", 0.000046) + ; ("VM.get_memory_static_max", 0.000046) + ; ("VGPU.get_all_records", 0.000046) + ; ("host.get_API_version_major", 0.000046) + ; ("SDN_controller.get_all", 0.000046) + ; ("PBD.get_other_config", 0.000046) + ; ("host.get_cpu_info", 0.000046) + ; ("network.get_MTU", 0.000045) + ; ("VM.get_VGPUs", 0.000045) + ; ("host.get_all", 0.000045) + ; ("VLAN.get_all_records", 0.000045) + ; ("Repository.get_uuid", 0.000045) + ; ("VDI.get_location", 0.000045) + ; ("host.get_uuid", 0.000045) + ; ("network.get_all", 0.000044) + ; ("VGPU.get_all_records_where", 0.000044) + ; ("session.get_this_host", 0.000044) + ; ("host_patch.get_all_records", 0.000044) + ; ("SR.get_other_config", 0.000043) + ; ("host.get_software_version", 0.000043) + ; ("pool.get_telemetry_next_collection", 0.000043) + ; ("Bond.get_uuid", 0.000043) + ; ("Observer.get_all_records", 0.000043) + ; ("secret.get_by_uuid", 0.000043) + ; ("VDI.get_name_label", 0.000043) + ; ("pool_update.get_by_uuid", 0.000042) + ; ("SR.get_type", 0.000042) + ; ("VIF.get_uuid", 0.000042) + ; ("SM.get_record", 0.000041) + ; ("SM.get_driver_filename", 0.000041) + ; ("Cluster_host.get_all", 0.000041) + ; ("subject.get_record", 0.000041) + ; ("pool.get_master", 0.000040) + ; ("VDI.get_name_description", 0.000040) + ; ("task.get_status", 0.000040) + ; ("VTPM.get_by_uuid", 0.000040) + ; ("host.get_external_auth_service_name", 0.000040) + ; ("VDI.get_SR", 0.000040) + ; ("pool.get_custom_uefi_certificates", 0.000040) + ; ("host.get_pending_guidances", 0.000040) + ; ("host.get_suspend_image_sr", 0.000040) + ; ("host.get_ssh_enabled_timeout", 0.000040) + ; ("VLAN.get_untagged_PIF", 0.000039) + ; ("PBD.get_device_config", 0.000039) + ; ("task.get_other_config", 0.000039) + ; ("VDI.get_other_config", 0.000039) + ; 
("network.get_uuid", 0.000039) + ; ("SR.get_uuid", 0.000039) + ; ("host.get_bios_strings", 0.000039) + ; ("pool.get_uefi_certificates", 0.000039) + ; ("VM.get_power_state", 0.000039) + ; ("Bond.get_all_records", 0.000039) + ; ("pool_update.get_uuid", 0.000039) + ; ("pool.get_cpu_info", 0.000039) + ; ("pool.get_name_label", 0.000039) + ; ("pool.get_current_operations", 0.000039) + ; ("host.get_enabled", 0.000038) + ; ("PBD.get_currently_attached", 0.000038) + ; ("PIF.get_primary_address_type", 0.000038) + ; ("PIF.get_VLAN", 0.000038) + ; ("host.get_updates", 0.000037) + ; ("SR.get_name_label", 0.000037) + ; ("host.get_console_idle_timeout", 0.000037) + ; ("VM.get_VTPMs", 0.000037) + ; ("GPU_group.get_record", 0.000037) + ; ("VDI.get_allow_caching", 0.000036) + ; ("VDI.get_snapshot_time", 0.000036) + ; ("host.get_ssh_auto_mode", 0.000036) + ; ("host.get_address", 0.000036) + ; ("PGPU.get_record", 0.000036) + ; ("role.get_by_name_label", 0.000036) + ; ("pool.get_license_server", 0.000035) + ; ("VIF.get_network", 0.000035) + ; ("VGPU_type.get_all_records", 0.000035) + ; ("network.get_VIFs", 0.000035) + ; ("VGPU_type.get_record", 0.000035) + ; ("VDI.get_type", 0.000035) + ; ("SR.get_content_type", 0.000034) + ; ("PGPU.get_PCI", 0.000034) + ; ("Cluster_host.get_all_records", 0.000034) + ; ("Cluster.get_all_records", 0.000034) + ; ("crashdump.get_all_records", 0.000034) + ; ("PIF.get_uuid", 0.000034) + ; ("secret.get_value", 0.000034) + ; ("SR.get_name_description", 0.000034) + ; ("VM.get_domid", 0.000034) + ; ("host.get_license_server", 0.000034) + ; ("PVS_proxy.get_all_records", 0.000033) + ; ("pool.get_uuid", 0.000033) + ; ("Repository.get_all_records", 0.000033) + ; ("VGPU.get_scheduled_to_be_resident_on", 0.000033) + ; ("network_sriov.get_all_records", 0.000033) + ; ("Observer.get_uuid", 0.000032) + ; ("task.get_error_info", 0.000032) + ; ("tunnel.get_all_records", 0.000032) + ; ("pool.get_vswitch_controller", 0.000032) + ; ("VM_group.get_all_records", 0.000031) + ; 
("VM_appliance.get_all_records", 0.000031) + ; ("PCI.get_record", 0.000030) + ; ("VGPU.get_device", 0.000030) + ; ("host_crashdump.get_all_records", 0.000030) + ; ("subject.get_all_records", 0.000030) + ; ("VTPM.get_all_records", 0.000029) + ; ("USB_group.get_all_records", 0.000029) + ; ("pool_update.get_name_label", 0.000029) + ; ("pool_update.get_vdi", 0.000029) + ; ("VGPU_type.get_by_uuid", 0.000028) + ; ("VBD.get_empty", 0.000028) + ; ("VGPU.get_record", 0.000027) + ; ("VGPU.get_by_uuid", 0.000026) + ; ("subject.get_all", 0.000026) + ; ("VM.get_suspend_SR", 0.000025) + ; ("GPU_group.get_by_uuid", 0.000025) + ; ("subject.get_all_records_where", 0.000025) + ; ("session.local_logout", 0.000025) + ; ("PVS_server.get_all_records", 0.000024) + ; ("PVS_site.get_all_records", 0.000024) + ; ("subject.get_uuid", 0.000023) + ; ("secret.get_uuid", 0.000023) + ; ("VGPU.get_uuid", 0.000023) + ; ("VTPM.get_record", 0.000023) + ; ("subject.get_by_uuid", 0.000022) + ; ("VTPM.get_all_records_where", 0.000022) + ; ("VM.get_is_control_domain", 0.000021) + ; ("pool_update.get_installation_size", 0.000021) + ; ("VGPU_type.get_model_name", 0.000020) + ; ("task.get_backtrace", 0.000018) + ; ("VTPM.get_uuid", 0.000017) + ; ("PBD.get_uuid", 0.000017) + ] + ) + +let median_token_cost = + let arr = Hashtbl.to_seq_values token_costs |> Array.of_seq in + let n = Array.length arr in + if n = 0 then + 0. + else ( + Array.sort Float.compare arr ; + if n mod 2 = 1 then + arr.(n / 2) + else + (arr.((n / 2) - 1) +. arr.(n / 2)) /. 2.0 + ) + +let get_token_cost name = + let amount = Hashtbl.find_opt token_costs name in + Option.value ~default:median_token_cost amount diff --git a/ocaml/xapi/xapi_rate_limit.mli b/ocaml/xapi/xapi_rate_limit.mli new file mode 100644 index 00000000000..0325492fe3e --- /dev/null +++ b/ocaml/xapi/xapi_rate_limit.mli @@ -0,0 +1,38 @@ +(* + * Copyright (C) Citrix Systems Inc. 
+ * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +module Key = Rate_limit.Bucket_table.Key + +val submit_sync : client_id:Key.t -> callback:(unit -> 'a) -> float -> 'a + +val submit : client_id:Key.t -> callback:(unit -> unit) -> float -> unit + +val peek : client_id:Key.t -> float option + +val get_token_cost : string -> float + +val median_token_cost : float + +val create : + __context:Context.t + -> user_agent:string + -> host_ip:string + -> burst_size:float + -> fill_rate:float + -> [`Rate_limit] Ref.t + +val destroy : __context:Context.t -> self:[`Rate_limit] API.Ref.t -> unit + +val register : __context:Context.t -> unit +(** Create token buckets in the bucket table for each record in the database *) diff --git a/ocaml/xapi/xapi_session.mli b/ocaml/xapi/xapi_session.mli index 10baf03abc2..25ffa3fa366 100644 --- a/ocaml/xapi/xapi_session.mli +++ b/ocaml/xapi/xapi_session.mli @@ -15,8 +15,6 @@ * @group XenAPI functions *) -(** {2 (Fill in Title!)} *) - (* TODO: consider updating sm_exec.ml and removing login_no_password from this mli *) val login_no_password : __context:Context.t diff --git a/opam/rate-limit.opam b/opam/rate-limit.opam new file mode 100644 index 00000000000..e5114dc41fb --- /dev/null +++ b/opam/rate-limit.opam @@ -0,0 +1,31 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +synopsis: "Simple token bucket-based rate-limiting" +maintainer: ["Xapi project maintainers"] +authors: ["xen-api@lists.xen.org"] +license: 
"LGPL-2.1-only WITH OCaml-LGPL-linking-exception" +homepage: "https://xapi-project.github.io/" +bug-reports: "https://github.com/xapi-project/xen-api/issues" +depends: [ + "dune" {>= "3.20"} + "ocaml" {>= "4.12"} + "xapi-log" + "xapi-stdext-unix" + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/xapi-project/xen-api.git" +x-maintenance-intent: ["(latest)"] diff --git a/quality-gate.sh b/quality-gate.sh index c82a98ea57f..88ef959dc9e 100755 --- a/quality-gate.sh +++ b/quality-gate.sh @@ -2,146 +2,145 @@ set -e -list-hd () { - N=244 - LIST_HD=$(git grep -r --count 'List.hd' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) - if [ "$LIST_HD" -eq "$N" ]; then - echo "OK counted $LIST_HD List.hd usages" - else - echo "ERROR expected $N List.hd usages, got $LIST_HD" 1>&2 - exit 1 - fi +list-hd() { + N=244 + LIST_HD=$(git grep -r --count 'List.hd' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) + if [ "$LIST_HD" -eq "$N" ]; then + echo "OK counted $LIST_HD List.hd usages" + else + echo "ERROR expected $N List.hd usages, got $LIST_HD" 1>&2 + exit 1 + fi } -verify-cert () { - N=13 - NONE=$(git grep -r --count 'verify_cert:None' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) - if [ "$NONE" -eq "$N" ]; then - echo "OK counted $NONE usages of verify_cert:None" - else - echo "ERROR expected $N verify_cert:None usages, got $NONE" 1>&2 - exit 1 - fi +verify-cert() { + N=13 + NONE=$(git grep -r --count 'verify_cert:None' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) + if [ "$NONE" -eq "$N" ]; then + echo "OK counted $NONE usages of verify_cert:None" + else + echo "ERROR expected $N verify_cert:None usages, got $NONE" 1>&2 + exit 1 + fi } -mli-files () { - N=457 - X="ocaml/tests" - X+="|ocaml/quicktest" - X+="|ocaml/message-switch/core_test" - # do not count ml files from the tests in ocaml/{tests/quicktest} - 
M=$(comm -23 <(git ls-files -- '**/*.ml' | sed 's/.ml$//' | sort) \ - <(git ls-files -- '**/*.mli' | sed 's/.mli$//' | sort) |\ +mli-files() { + N=458 + X="ocaml/tests" + X+="|ocaml/quicktest" + X+="|ocaml/message-switch/core_test" + # do not count ml files from the tests in ocaml/{tests/quicktest} + M=$(comm -23 <(git ls-files -- '**/*.ml' | sed 's/.ml$//' | sort) \ + <(git ls-files -- '**/*.mli' | sed 's/.mli$//' | sort) | grep -cvE "$X") - if [ "$M" -eq "$N" ]; then - echo "OK counted $M .ml files without an .mli" - else - echo "ERROR expected $N .ml files without .mlis, got $M."\ - "If you created some .ml files, they are probably missing corresponding .mli's" 1>&2 - exit 1 - fi + if [ "$M" -eq "$N" ]; then + echo "OK counted $M .ml files without an .mli" + else + echo "ERROR expected $N .ml files without .mlis, got $M." \ + "If you created some .ml files, they are probably missing corresponding .mli's" 1>&2 + exit 1 + fi } -structural-equality () { - N=7 - EQ=$(git grep -r --count ' == ' -- '**/*.ml' ':!ocaml/sdk-gen/**/*.ml' | cut -d ':' -f 2 | paste -sd+ - | bc) - if [ "$EQ" -eq "$N" ]; then - echo "OK counted $EQ usages of ' == '" - else - echo "ERROR expected $N usages of ' == ', got $EQ; use = rather than ==" 1>&2 - exit 1 - fi - - if git grep -r --count ' != ' -- '**/*.ml' ':!ocaml/sdk-gen/**/*.ml'; then - echo "ERROR expected no usages of ' != '; use <> rather than !=" 1>&2 - exit 1 - else - echo "OK found no usages of ' != '" - fi +structural-equality() { + N=7 + EQ=$(git grep -r --count ' == ' -- '**/*.ml' ':!ocaml/sdk-gen/**/*.ml' | cut -d ':' -f 2 | paste -sd+ - | bc) + if [ "$EQ" -eq "$N" ]; then + echo "OK counted $EQ usages of ' == '" + else + echo "ERROR expected $N usages of ' == ', got $EQ; use = rather than ==" 1>&2 + exit 1 + fi + + if git grep -r --count ' != ' -- '**/*.ml' ':!ocaml/sdk-gen/**/*.ml'; then + echo "ERROR expected no usages of ' != '; use <> rather than !=" 1>&2 + exit 1 + else + echo "OK found no usages of ' != '" + fi } 
-vtpm-unimplemented () { - N=2 - VTPM=$(git grep -r --count 'maybe_raise_vtpm_unimplemented' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) - if [ "$VTPM" -eq "$N" ]; then - echo "OK found $VTPM usages of vtpm unimplemented errors" - else - echo "ERROR expected $N usages of unimplemented vtpm functionality, got $VTPM." 1>&2 - exit 1 - fi +vtpm-unimplemented() { + N=2 + VTPM=$(git grep -r --count 'maybe_raise_vtpm_unimplemented' -- **/*.ml | cut -d ':' -f 2 | paste -sd+ - | bc) + if [ "$VTPM" -eq "$N" ]; then + echo "OK found $VTPM usages of vtpm unimplemented errors" + else + echo "ERROR expected $N usages of unimplemented vtpm functionality, got $VTPM." 1>&2 + exit 1 + fi } -vtpm-fields () { - A=$(git grep -hc "vTPM'_.*:" ocaml/xapi/importexport.ml) - B=$(git grep -hc ' field' ocaml/idl/datamodel_vtpm.ml) - case "$A/$B" in - 5/6) - echo "OK found $A/$B VTPM fields in importexport.ml datamodel_vtpm.ml" - ;; - *) - echo "ERROR have VTPM fields changed? $A/$B - check importexport.ml" 1>&2 - exit 1 - ;; - esac +vtpm-fields() { + A=$(git grep -hc "vTPM'_.*:" ocaml/xapi/importexport.ml) + B=$(git grep -hc ' field' ocaml/idl/datamodel_vtpm.ml) + case "$A/$B" in + 5/6) + echo "OK found $A/$B VTPM fields in importexport.ml datamodel_vtpm.ml" + ;; + *) + echo "ERROR have VTPM fields changed? $A/$B - check importexport.ml" 1>&2 + exit 1 + ;; + esac } -ocamlyacc () { - N=0 - OCAMLYACC=$(git grep -r -o --count "ocamlyacc" '**/dune' | wc -l) - if [ "$OCAMLYACC" -eq "$N" ]; then - echo "OK found $OCAMLYACC usages of ocamlyacc usages in dune files." - else - echo "ERROR expected $N usages of ocamlyacc in dune files, got $OCAMLYACC." 1>&2 - exit 1 - fi +ocamlyacc() { + N=0 + OCAMLYACC=$(git grep -r -o --count "ocamlyacc" '**/dune' | wc -l) + if [ "$OCAMLYACC" -eq "$N" ]; then + echo "OK found $OCAMLYACC usages of ocamlyacc usages in dune files." + else + echo "ERROR expected $N usages of ocamlyacc in dune files, got $OCAMLYACC." 
1>&2 + exit 1 + fi } - -unixgetenv () { - N=0 - UNIXGETENV=$(git grep -P -r -o --count 'getenv(?!_opt)' -- **/*.ml | wc -l) - if [ "$UNIXGETENV" -eq "$N" ]; then - echo "OK found $UNIXGETENV usages of exception-raising Unix.getenv in OCaml files." - else - echo "ERROR expected $N usages of exception-raising Unix.getenv in OCaml files, got $UNIXGETENV" 1>&2 - exit 1 - fi +unixgetenv() { + N=0 + UNIXGETENV=$(git grep -P -r -o --count 'getenv(?!_opt)' -- **/*.ml | wc -l) + if [ "$UNIXGETENV" -eq "$N" ]; then + echo "OK found $UNIXGETENV usages of exception-raising Unix.getenv in OCaml files." + else + echo "ERROR expected $N usages of exception-raising Unix.getenv in OCaml files, got $UNIXGETENV" 1>&2 + exit 1 + fi } -hashtblfind () { - N=33 - # Looks for all .ml files except the ones using Core.Hashtbl.find, - # which already returns Option - HASHTBLFIND=$(git grep -P -r --count 'Hashtbl.find(?!_opt)' -- '**/*.ml' ':!ocaml/xapi-storage-script/main.ml' | cut -d ':' -f 2 | paste -sd+ - | bc) - if [ "$HASHTBLFIND" -eq "$N" ]; then - echo "OK counted $HASHTBLFIND usages of exception-raising Hashtbl.find" - else - echo "ERROR expected $N usages of exception-raising Hashtbl.find, got $HASHTBLFIND" 1>&2 - exit 1 - fi +hashtblfind() { + N=33 + # Looks for all .ml files except the ones using Core.Hashtbl.find, + # which already returns Option + HASHTBLFIND=$(git grep -P -r --count 'Hashtbl.find(?!_opt)' -- '**/*.ml' ':!ocaml/xapi-storage-script/main.ml' | cut -d ':' -f 2 | paste -sd+ - | bc) + if [ "$HASHTBLFIND" -eq "$N" ]; then + echo "OK counted $HASHTBLFIND usages of exception-raising Hashtbl.find" + else + echo "ERROR expected $N usages of exception-raising Hashtbl.find, got $HASHTBLFIND" 1>&2 + exit 1 + fi } -unnecessary-length () { - N=0 - local_grep () { - git grep -r -o --count "$1" -- '**/*.ml' | wc -l - } - UNNECESSARY_LENGTH=$(local_grep "List.length.*=+\s*0") - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "0\s*=+\s*List.length"))) - 
UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "List.length.*\s>\s*0"))) - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "List.length.*\s<>\s*0"))) - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "0\s*<\s*List.length"))) - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "0\s*<>\s*List.length"))) - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "List.length.*\s<\s*1"))) - UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH+$(local_grep "1\s*>\s*List.length"))) - if [ "$UNNECESSARY_LENGTH" -eq "$N" ]; then - echo "OK found $UNNECESSARY_LENGTH unnecessary usages of List.length in OCaml files." - else - echo "ERROR expected $N unnecessary usages of List.length in OCaml files, +unnecessary-length() { + N=0 + local_grep() { + git grep -r -o --count "$1" -- '**/*.ml' | wc -l + } + UNNECESSARY_LENGTH=$(local_grep "List.length.*=+\s*0") + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "0\s*=+\s*List.length"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "List.length.*\s>\s*0"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "List.length.*\s<>\s*0"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "0\s*<\s*List.length"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "0\s*<>\s*List.length"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "List.length.*\s<\s*1"))) + UNNECESSARY_LENGTH=$((UNNECESSARY_LENGTH + $(local_grep "1\s*>\s*List.length"))) + if [ "$UNNECESSARY_LENGTH" -eq "$N" ]; then + echo "OK found $UNNECESSARY_LENGTH unnecessary usages of List.length in OCaml files." + else + echo "ERROR expected $N unnecessary usages of List.length in OCaml files, got $UNNECESSARY_LENGTH. Use lst =/<> [] or match statements instead." 1>&2 - exit 1 - fi + exit 1 + fi } list-hd @@ -154,4 +153,3 @@ ocamlyacc unixgetenv hashtblfind unnecessary-length -