diff --git a/cmd/pineconesim/simulator/adversary/drop_packets.go b/cmd/pineconesim/simulator/adversary/drop_packets.go index 32f1e8a1..e9fa99a3 100644 --- a/cmd/pineconesim/simulator/adversary/drop_packets.go +++ b/cmd/pineconesim/simulator/adversary/drop_packets.go @@ -119,8 +119,8 @@ func NewAdversaryRouter(log *log.Logger, sk ed25519.PrivateKey, debug bool) *Adv return adversary } -func (a *AdversaryRouter) Subscribe(ch chan events.Event) { - a.rtr.Subscribe(ch) +func (a *AdversaryRouter) Subscribe(ch chan events.Event) router.NodeState { + return a.rtr.Subscribe(ch) } func (a *AdversaryRouter) PublicKey() types.PublicKey { diff --git a/cmd/pineconesim/simulator/nodes.go b/cmd/pineconesim/simulator/nodes.go index 5bf43a0b..3abf06d6 100644 --- a/cmd/pineconesim/simulator/nodes.go +++ b/cmd/pineconesim/simulator/nodes.go @@ -127,14 +127,14 @@ func (sim *Simulator) StartNodeEventHandler(t string, nodeType APINodeType) { ch := make(chan events.Event) handler := eventHandler{node: t, ch: ch} quit := make(chan bool) - go handler.Run(quit, sim) - sim.nodes[t].Subscribe(ch) + nodeState := sim.nodes[t].Subscribe(ch) sim.nodeRunnerChannelsMutex.Lock() sim.nodeRunnerChannels[t] = append(sim.nodeRunnerChannels[t], quit) sim.nodeRunnerChannelsMutex.Unlock() - phony.Block(sim.State, func() { sim.State._addNode(t, sim.nodes[t].PublicKey().String(), nodeType) }) + phony.Block(sim.State, func() { sim.State._addNode(t, sim.nodes[t].PublicKey().String(), nodeType, nodeState) }) + go handler.Run(quit, sim) } func (sim *Simulator) RemoveNode(node string) { diff --git a/cmd/pineconesim/simulator/router.go b/cmd/pineconesim/simulator/router.go index 8cc7713d..96f36331 100644 --- a/cmd/pineconesim/simulator/router.go +++ b/cmd/pineconesim/simulator/router.go @@ -31,7 +31,7 @@ import ( type SimRouter interface { PublicKey() types.PublicKey Connect(conn net.Conn, options ...router.ConnectionOption) (types.SwitchPortID, error) - Subscribe(ch chan events.Event) + Subscribe(ch chan events.Event) router.NodeState Ping(ctx context.Context, a net.Addr) (uint16, time.Duration, error) Coords() types.Coordinates ConfigureFilterDefaults(rates adversary.DropRates) @@ -44,8 +44,8 @@ type DefaultRouter struct { pings sync.Map // types.PublicKey -> chan struct{} } -func (r *DefaultRouter) Subscribe(ch chan events.Event) { - r.rtr.Subscribe(ch) +func (r *DefaultRouter) Subscribe(ch chan events.Event) router.NodeState { + return r.rtr.Subscribe(ch) } func (r *DefaultRouter) PublicKey() types.PublicKey { diff --git a/cmd/pineconesim/simulator/simulator.go b/cmd/pineconesim/simulator/simulator.go index fd2f1e27..309cfd8e 100644 --- a/cmd/pineconesim/simulator/simulator.go +++ b/cmd/pineconesim/simulator/simulator.go @@ -382,6 +382,8 @@ func (sim *Simulator) handleTreeRootAnnUpdate(node string, root string, sequence rootName := "" if peerNode, err := sim.State.GetNodeName(root); err == nil { rootName = peerNode + } else { + log.Fatalf("Cannot convert %s to root for %s", root, node) } sim.State.Act(nil, func() { sim.State._updateTreeRootAnnouncement(node, rootName, sequence, time, coords) }) } diff --git a/cmd/pineconesim/simulator/state.go b/cmd/pineconesim/simulator/state.go index 5ece7b8b..85b295ab 100644 --- a/cmd/pineconesim/simulator/state.go +++ b/cmd/pineconesim/simulator/state.go @@ -19,6 +19,7 @@ import ( "reflect" "github.com/Arceliar/phony" + "github.com/matrix-org/pinecone/router" ) type RootAnnouncement struct { @@ -174,13 +175,7 @@ func (s *StateAccessor) GetNodeName(peerID string) (string, error) { node := "" err := fmt.Errorf("Provided peerID is not associated with a known node") - phony.Block(s, func() { - for k, v := range s._state.Nodes { - if v.PeerID == peerID { - node, err = k, nil - } - } - }) + phony.Block(s, func() { node, err = s._getNodeName(peerID) }) return node, err } @@ -195,8 +190,54 @@ func (s *StateAccessor) GetNodeCoords(name string) []uint64 { return coords } -func (s *StateAccessor) _addNode(name string, peerID string, nodeType APINodeType) { +func (s *StateAccessor) _getNodeName(peerID string) (string, error) { + node := "" + err := fmt.Errorf("Provided peerID is not associated with a known node") + + for k, v := range s._state.Nodes { + if v.PeerID == peerID { + node, err = k, nil + } + } + + return node, err +} + +func (s *StateAccessor) _addNode(name string, peerID string, nodeType APINodeType, nodeState router.NodeState) { s._state.Nodes[name] = NewNodeState(peerID, nodeType) + if peernode, err := s._getNodeName(nodeState.Parent); err == nil { + s._state.Nodes[name].Parent = peernode + } + connections := map[int]string{} + for i, node := range nodeState.Connections { + if i == 0 { + // NOTE : Skip connection on port 0 since it is the loopback port + continue + } + if peernode, err := s._getNodeName(node); err == nil { + connections[i] = peernode + } + } + s._state.Nodes[name].Connections = connections + s._state.Nodes[name].Coords = nodeState.Coords + root := "" + if peernode, err := s._getNodeName(nodeState.Announcement.RootPublicKey.String()); err == nil { + root = peernode + } + announcement := RootAnnouncement{ + Root: root, + Sequence: uint64(nodeState.Announcement.RootSequence), + Time: nodeState.AnnouncementTime, + } + s._state.Nodes[name].Announcement = announcement + if peernode, err := s._getNodeName(nodeState.AscendingPeer); err == nil { + s._state.Nodes[name].AscendingPeer = peernode + } + s._state.Nodes[name].AscendingPathID = nodeState.AscendingPathID + if peernode, err := s._getNodeName(nodeState.DescendingPeer); err == nil { + s._state.Nodes[name].DescendingPeer = peernode + } + s._state.Nodes[name].DescendingPathID = nodeState.DescendingPathID s._publish(NodeAdded{Node: name, PublicKey: peerID, NodeType: int(nodeType)}) } diff --git a/router/api.go b/router/api.go index 039e40d3..2d700ac5 100644 --- a/router/api.go +++ b/router/api.go @@ -38,11 +38,64 @@ type PeerInfo struct { Zone string } +type NodeState struct { + PeerID string + Connections map[int]string + Parent string + Coords []uint64 + Announcement types.SwitchAnnouncement + AnnouncementTime uint64 + AscendingPeer string + AscendingPathID string + DescendingPeer string + DescendingPathID string +} + // Subscribe registers a subscriber to this node's events -func (r *Router) Subscribe(ch chan<- events.Event) { +func (r *Router) Subscribe(ch chan<- events.Event) NodeState { + var stateCopy NodeState phony.Block(r, func() { r._subscribers[ch] = &phony.Inbox{} + stateCopy.PeerID = r.public.String() + connections := map[int]string{} + for _, p := range r.state._peers { + if p == nil { + continue + } + connections[int(p.port)] = p.public.String() + } + stateCopy.Connections = connections + parent := "" + if r.state._parent != nil { + parent = r.state._parent.public.String() + } + stateCopy.Parent = parent + coords := []uint64{} + for _, coord := range r.Coords() { + coords = append(coords, uint64(coord)) + } + stateCopy.Coords = coords + announcement := r.state._rootAnnouncement() + stateCopy.Announcement = announcement.SwitchAnnouncement + stateCopy.AnnouncementTime = uint64(announcement.receiveTime.UnixNano()) + asc := "" + ascPath := "" + if r.state._ascending != nil { + asc = r.state._ascending.PublicKey.String() + ascPath = hex.EncodeToString(r.state._ascending.PathID[:]) + } + stateCopy.AscendingPeer = asc + stateCopy.AscendingPathID = ascPath + desc := "" + descPath := "" + if r.state._descending != nil { + desc = r.state._descending.PublicKey.String() + descPath = hex.EncodeToString(r.state._descending.PathID[:]) + } + stateCopy.DescendingPeer = desc + stateCopy.DescendingPathID = descPath }) + return stateCopy } func (r *Router) Coords() types.Coordinates { diff --git a/router/tests/basic_integration_test.go b/router/tests/basic_integration_test.go new file mode 100644 index 00000000..5adad0b4 --- /dev/null +++ b/router/tests/basic_integration_test.go @@ -0,0 +1,219 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "log" + "sort" + "testing" + "time" + + "github.com/matrix-org/pinecone/cmd/pineconesim/simulator" +) + +const SettlingTime time.Duration = time.Second * 2 +const TestTimeout time.Duration = time.Second * 5 + +type TreeValidationState struct { + roots map[string]string + correctRoot string +} + +func TestNodesAgreeOnCorrectTreeRoot(t *testing.T) { + t.Parallel() + // Arrange + scenario := NewScenarioFixture(t) + nodes := []string{"Alice", "Bob", "Charlie"} + scenario.AddStandardNodes(nodes) + + // Act + scenario.AddPeerConnections([]NodePair{NodePair{"Alice", "Bob"}, NodePair{"Bob", "Charlie"}}) + + // Assert + stateCapture := func(state simulator.State) interface{} { + lastRoots := make(map[string]string) + for _, node := range nodes { + lastRoots[node] = state.Nodes[node].Announcement.Root + } + + nodesByKey := make(byKey, 0, len(state.Nodes)) + for key, value := range state.Nodes { + nodesByKey = append(nodesByKey, Node{key, value.PeerID}) + } + sort.Sort(nodesByKey) + + correctRoot := nodesByKey[len(nodesByKey)-1].name + + return TreeValidationState{roots: lastRoots, correctRoot: correctRoot} + } + + nodesAgreeOnCorrectTreeRoot := func(prevState interface{}, event simulator.SimEvent) (interface{}, EventHandlerResult) { + switch state := prevState.(type) { + case TreeValidationState: + action := DoNothing + switch e := event.(type) { + case simulator.TreeRootAnnUpdate: + if state.roots[e.Node] != e.Root { + log.Printf("Root changed for %s to %s", e.Node, e.Root) + state.roots[e.Node] = e.Root + } else { + log.Printf("Got duplicate root info for %s", e.Node) + break + } + + nodesAgreeOnRoot := true + rootSample := "" + for _, node := range state.roots { + rootSample = node + for _, comparison := range state.roots { + if node != comparison { + nodesAgreeOnRoot = false + break + } + } + } + + if nodesAgreeOnRoot && state.correctRoot == rootSample { + log.Println("Start settling for tree test") + action = StartSettlingTimer + } else { + log.Println("Stop settling for tree test") + action = StopSettlingTimer + } + } + + return state, action + } + + return prevState, StopSettlingTimer + } + + scenario.Validate(stateCapture, nodesAgreeOnCorrectTreeRoot, SettlingTime, TestTimeout) +} + +type SnakeNeighbours struct { + asc string + desc string +} + +type SnakeValidationState struct { + snake map[string]SnakeNeighbours + correctSnake map[string]SnakeNeighbours +} + +type Node struct { + name string + key string +} + +type byKey []Node + +func (l byKey) Len() int { + return len(l) +} + +func (l byKey) Less(i, j int) bool { + return l[i].key < l[j].key +} + +func (l byKey) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} + +func TestNodesAgreeOnCorrectSnakeFormation(t *testing.T) { + t.Parallel() + // Arrange + scenario := NewScenarioFixture(t) + nodes := []string{"Alice", "Bob", "Charlie"} + scenario.AddStandardNodes(nodes) + + // Act + scenario.AddPeerConnections([]NodePair{NodePair{"Alice", "Bob"}, NodePair{"Bob", "Charlie"}}) + + // Assert + stateCapture := func(state simulator.State) interface{} { + snakeNeighbours := make(map[string]SnakeNeighbours) + for _, node := range nodes { + asc := state.Nodes[node].AscendingPeer + desc := state.Nodes[node].DescendingPeer + snakeNeighbours[node] = SnakeNeighbours{asc: asc, desc: desc} + } + + nodesByKey := make(byKey, 0, len(state.Nodes)) + for key, value := range state.Nodes { + nodesByKey = append(nodesByKey, Node{key, value.PeerID}) + } + sort.Sort(nodesByKey) + + correctSnake := make(map[string]SnakeNeighbours) + lowest := SnakeNeighbours{asc: nodesByKey[1].name, desc: ""} + middle := SnakeNeighbours{asc: nodesByKey[2].name, desc: nodesByKey[0].name} + highest := SnakeNeighbours{asc: "", desc: nodesByKey[1].name} + correctSnake[nodesByKey[0].name] = lowest + correctSnake[nodesByKey[1].name] = middle + correctSnake[nodesByKey[2].name] = highest + + return SnakeValidationState{snakeNeighbours, correctSnake} + } + + nodesAgreeOnCorrectSnakeFormation := func(prevState interface{}, event simulator.SimEvent) (interface{}, EventHandlerResult) { + switch state := prevState.(type) { + case SnakeValidationState: + isSnakeCorrect := func() bool { + snakeIsCorrect := true + for key, val := range state.snake { + if val.asc != state.correctSnake[key].asc || val.desc != state.correctSnake[key].desc { + snakeIsCorrect = false + break + } + } + return snakeIsCorrect + } + + snakeWasCorrect := isSnakeCorrect() + + action := DoNothing + updateReceived := false + switch e := event.(type) { + case simulator.SnakeAscUpdate: + updateReceived = true + if node, ok := state.snake[e.Node]; ok { + node.asc = e.Peer + state.snake[e.Node] = node + } + case simulator.SnakeDescUpdate: + updateReceived = true + if node, ok := state.snake[e.Node]; ok { + node.desc = e.Peer + state.snake[e.Node] = node + } + } + + if updateReceived { + if isSnakeCorrect() && !snakeWasCorrect { + action = StartSettlingTimer + } else { + action = StopSettlingTimer + } + } + + return state, action + } + + return prevState, StopSettlingTimer + } + + scenario.Validate(stateCapture, nodesAgreeOnCorrectSnakeFormation, SettlingTime, TestTimeout) +} diff --git a/router/tests/scenario_fixture.go b/router/tests/scenario_fixture.go new file mode 100644 index 00000000..33718de7 --- /dev/null +++ b/router/tests/scenario_fixture.go @@ -0,0 +1,151 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/matrix-org/pinecone/cmd/pineconesim/simulator" +) + +type EventHandlerResult int + +const ( + DoNothing EventHandlerResult = iota + StopSettlingTimer + StartSettlingTimer +) + +type InitialStateCapture func(state simulator.State) interface{} +type EventHandler func(prevState interface{}, event simulator.SimEvent) (interface{}, EventHandlerResult) + +type NodePair struct { + A string + B string +} + +type ScenarioFixture struct { + t *testing.T + log *log.Logger + sim *simulator.Simulator +} + +func NewScenarioFixture(t *testing.T) ScenarioFixture { + log := log.New(os.Stdout, "\u001b[36m***\u001b[0m ", 0) + useSockets := false + runPing := false + acceptCommands := true + simulator := simulator.NewSimulator(log, useSockets, runPing, acceptCommands) + + return ScenarioFixture{ + t: t, + log: log, + sim: simulator, + } +} + +func (s *ScenarioFixture) AddStandardNodes(nodes []string) { + for _, node := range nodes { + cmd := simulator.AddNode{ + Node: node, + NodeType: simulator.DefaultNode, + } + cmd.Run(s.log, s.sim) + } +} + +func (s *ScenarioFixture) AddAdversaryNodes(nodes []string) { + for _, node := range nodes { + cmd := simulator.AddNode{ + Node: node, + NodeType: simulator.GeneralAdversaryNode, + } + cmd.Run(s.log, s.sim) + } +} + +func (s *ScenarioFixture) AddPeerConnections(conns []NodePair) { + for _, pair := range conns { + cmd := simulator.AddPeer{ + Node: pair.A, + Peer: pair.B, + } + cmd.Run(s.log, s.sim) + } +} + +func (s *ScenarioFixture) SubscribeToSimState(ch chan simulator.SimEvent) simulator.State { + return s.sim.State.Subscribe(ch) +} + +func (s *ScenarioFixture) Validate(initialState InitialStateCapture, eventHandler EventHandler, settlingTime time.Duration, timeout time.Duration) { + testTimeout := time.NewTimer(timeout) + defer testTimeout.Stop() + + quit := make(chan bool) + output := make(chan string) + go assertState(s, initialState, eventHandler, quit, output, settlingTime) + + failed := false + + select { + case <-testTimeout.C: + failed = true + quit <- true + case <-output: + log.Println("Test passed") + } + + if failed { + state := <-output + s.t.Fatalf("Test timeout reached. Current State: %s", state) + } +} + +func assertState(scenario *ScenarioFixture, stateCapture InitialStateCapture, eventHandler EventHandler, quit chan bool, output chan string, settlingTime time.Duration) { + settlingTimer := time.NewTimer(settlingTime) + settlingTimer.Stop() + + simUpdates := make(chan simulator.SimEvent) + state := scenario.SubscribeToSimState(simUpdates) + + prevState := stateCapture(state) + + for { + select { + case <-quit: + output <- fmt.Sprintf("%+v", prevState) + return + case <-settlingTimer.C: + output <- "PASS" + case event := <-simUpdates: + newState, newResult := eventHandler(prevState, event) + switch newResult { + case StartSettlingTimer: + log.Println("Starting settling timer") + settlingTimer.Reset(settlingTime) + case StopSettlingTimer: + log.Println("Stopping settling timer") + settlingTimer.Stop() + } + + prevState = newState + } + } +}