diff --git a/src/htm/algorithms/Connections.cpp b/src/htm/algorithms/Connections.cpp index fe5c53d8f9..b6361fd35f 100644 --- a/src/htm/algorithms/Connections.cpp +++ b/src/htm/algorithms/Connections.cpp @@ -763,49 +763,56 @@ std::ostream& operator<< (std::ostream& stream, const Connections& self) -bool Connections::operator==(const Connections &other) const { - if (cells_.size() != other.cells_.size()) - return false; +bool Connections::operator==(const Connections &o) const { + try { + NTA_CHECK (cells_.size() == o.cells_.size()) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size(); + NTA_CHECK (cells_ == o.cells_) << "Connections equals: cells_" << cells_.size() << " vs. " << o.cells_.size(); + + NTA_CHECK (segments_ == o.segments_ ) << "Connections equals: segments_"; + NTA_CHECK (destroyedSegments_ == o.destroyedSegments_ ) << "Connections equals: destroyedSegments_"; + + NTA_CHECK (synapses_ == o.synapses_ ) << "Connections equals: synapses_"; + NTA_CHECK (destroyedSynapses_ == o.destroyedSynapses_ ) << "Connections equals: destroyedSynapses_"; + + + //also check underlying datastructures (segments, and subsequently synapses). Can be time consuming. + //1.cells: + for(const auto cellD : cells_) { + //2.segments: + const auto& segments = cellD.segments; + for(const auto seg : segments) { + NTA_CHECK( dataForSegment(seg) == o.dataForSegment(seg) ) << "CellData equals: segmentData"; + //3.synapses: + const auto& synapses = dataForSegment(seg).synapses; + for(const auto syn : synapses) { + NTA_CHECK(dataForSynapse(syn) == o.dataForSynapse(syn) ) << "SegmentData equals: synapseData"; + } + } + } - if(iteration_ != other.iteration_) return false; - for (CellIdx i = 0; i < static_cast(cells_.size()); i++) { - const CellData &cellData = cells_[i]; - const CellData &otherCellData = other.cells_[i]; + NTA_CHECK (connectedThreshold_ == o.connectedThreshold_ ) << "Connections equals: connectedThreshold_"; + NTA_CHECK (iteration_ == o.iteration_ ) << "Connections equals: iteration_"; - if (cellData.segments.size() != otherCellData.segments.size()) { - return false; - } + NTA_CHECK(potentialSynapsesForPresynapticCell_ == o.potentialSynapsesForPresynapticCell_); + NTA_CHECK(connectedSynapsesForPresynapticCell_ == o.connectedSynapsesForPresynapticCell_); + NTA_CHECK(potentialSegmentsForPresynapticCell_ == o.potentialSegmentsForPresynapticCell_); + NTA_CHECK(connectedSegmentsForPresynapticCell_ == o.connectedSegmentsForPresynapticCell_); - for (SegmentIdx j = 0; j < static_cast(cellData.segments.size()); j++) { - const Segment segment = cellData.segments[j]; - const SegmentData &segmentData = segments_[segment]; - const Segment otherSegment = otherCellData.segments[j]; - const SegmentData &otherSegmentData = other.segments_[otherSegment]; + NTA_CHECK (nextSegmentOrdinal_ == o.nextSegmentOrdinal_ ) << "Connections equals: nextSegmentOrdinal_"; + NTA_CHECK (nextSynapseOrdinal_ == o.nextSynapseOrdinal_ ) << "Connections equals: nextSynapseOrdinal_"; - if (segmentData.synapses.size() != otherSegmentData.synapses.size() || - segmentData.cell != otherSegmentData.cell) { - return false; - } + NTA_CHECK (timeseries_ == o.timeseries_ ) << "Connections equals: timeseries_"; + NTA_CHECK (previousUpdates_ == o.previousUpdates_ ) << "Connections equals: previousUpdates_"; + NTA_CHECK (currentUpdates_ == o.currentUpdates_ ) << "Connections equals: currentUpdates_"; - for (SynapseIdx k = 0; k < static_cast(segmentData.synapses.size()); k++) { - const Synapse synapse = segmentData.synapses[k]; - const SynapseData &synapseData = synapses_[synapse]; - const Synapse otherSynapse = otherSegmentData.synapses[k]; - const SynapseData &otherSynapseData = other.synapses_[otherSynapse]; + NTA_CHECK (prunedSyns_ == o.prunedSyns_ ) << "Connections equals: prunedSyns_"; + NTA_CHECK (prunedSegs_ == o.prunedSegs_ ) << "Connections equals: prunedSegs_"; - if (synapseData.presynapticCell != otherSynapseData.presynapticCell || - synapseData.permanence != otherSynapseData.permanence) { - return false; - } - - // Two functionally identical instances may have different flatIdxs. - NTA_ASSERT(synapseData.segment == segment); - NTA_ASSERT(otherSynapseData.segment == otherSegment); - } - } + } catch(const htm::Exception& ex) { + std::cout << "Connection equals: differ! " << ex.what(); + return false; } - return true; } diff --git a/src/htm/algorithms/Connections.hpp b/src/htm/algorithms/Connections.hpp index f1455d6056..ee3fdf11b4 100644 --- a/src/htm/algorithms/Connections.hpp +++ b/src/htm/algorithms/Connections.hpp @@ -69,17 +69,40 @@ struct SynapseData: public Serializable { SynapseData() {} + //Serialization CerealAdapter; template void save_ar(Archive & ar) const { - ar(cereal::make_nvp("perm", permanence), - cereal::make_nvp("presyn", presynapticCell)); + ar(CEREAL_NVP(permanence), + CEREAL_NVP(presynapticCell), + CEREAL_NVP(segment), + CEREAL_NVP(presynapticMapIndex_), + CEREAL_NVP(id) + ); } template void load_ar(Archive & ar) { - ar( permanence, presynapticCell); + ar( permanence, presynapticCell, segment, presynapticMapIndex_, id); } + //operator== + bool operator==(const SynapseData& o) const { + try { + NTA_CHECK(presynapticCell == o.presynapticCell ) << "SynapseData equals: presynapticCell"; + NTA_CHECK(permanence == o.permanence ) << "SynapseData equals: permanence"; + NTA_CHECK(segment == o.segment ) << "SynapseData equals: segment"; + NTA_CHECK(presynapticMapIndex_ == o.presynapticMapIndex_ ) << "SynapseData equals: presynapticMapIndex_"; + NTA_CHECK(id == o.id ) << "SynapseData equals: id"; + } catch(const htm::Exception& ex) { + //NTA_WARN << "SynapseData equals: " << ex.what(); //Note: uncomment for debug, tells you + //where the diff is. It's perfectly OK for the "exception" to occur, as it just denotes + //that the data is NOT equal. + return false; + } + return true; + } + inline bool operator!=(const SynapseData& o) const { return !operator==(o); } + }; /** @@ -94,7 +117,7 @@ struct SynapseData: public Serializable { * @param cell * The cell that this segment is on. */ -struct SegmentData { +struct SegmentData: public Serializable { SegmentData(const CellIdx cell, Segment id, UInt32 lastUsed = 0) : cell(cell), numConnected(0), lastUsed(lastUsed), id(id) {} //default constructor std::vector synapses; @@ -102,6 +125,40 @@ struct SegmentData { SynapseIdx numConnected; //number of permanences from `synapses` that are >= synPermConnected, ie connected synapses UInt32 lastUsed = 0; //last used time (iteration). Used for segment pruning by "least recently used" (LRU) in `createSegment` Segment id; + + //Serialize + SegmentData() {}; //empty constructor for serialization, do not use + CerealAdapter; + template + void save_ar(Archive & ar) const { + ar(CEREAL_NVP(synapses), + CEREAL_NVP(cell), + CEREAL_NVP(numConnected), + CEREAL_NVP(lastUsed), + CEREAL_NVP(id) + ); + } + template + void load_ar(Archive & ar) { + ar( synapses, cell, numConnected, lastUsed, id); + } + + //equals op== + bool operator==(const SegmentData& o) const { + try { + NTA_CHECK(synapses == o.synapses) << "SegmentData equals: synapses"; + NTA_CHECK(cell == o.cell) << "SegmentData equals: cell"; + NTA_CHECK(numConnected == o.numConnected) << "SegmentData equals: numConnected"; + NTA_CHECK(lastUsed == o.lastUsed) << "SegmentData equals: lastUsed"; + NTA_CHECK(id == o.id) << "SegmentData equals: id"; + + } catch(const htm::Exception& ex) { + //NTA_WARN << "SegmentData equals: " << ex.what(); + return false; + } + return true; + } + inline bool operator!=(const SegmentData& o) const { return !operator==(o); } }; /** @@ -115,10 +172,35 @@ struct SegmentData { * Segments on this cell. * */ -struct CellData { +struct CellData : public Serializable { std::vector segments; + + //Serialization + CerealAdapter; + template + void save_ar(Archive & ar) const { + ar(CEREAL_NVP(segments) + ); + } + template + void load_ar(Archive & ar) { + ar( segments); + } + + //operator== + bool operator==(const CellData& o) const { + try { + NTA_CHECK( segments == o.segments ) << "CellData equals: segments"; + } catch(const htm::Exception& ex) { + //NTA_WARN << "CellData equals: " << ex.what(); + return false; + } + return true; + } + inline bool operator!=(const CellData& o) const { return !operator==(o); } }; + /** * A base class for Connections event handlers. * @@ -557,58 +639,58 @@ class Connections : public Serializable CerealAdapter; template void save_ar(Archive & ar) const { - // make this look like a queue of items to be sent. - // and a queue of sizes so we can distribute the - // correct number for each level when deserializing. - std::deque syndata; - std::deque sizes; - sizes.push_back(cells_.size()); - for (CellData cellData : cells_) { - const std::vector &segments = cellData.segments; - sizes.push_back(segments.size()); - for (Segment segment : segments) { - const SegmentData &segmentData = segments_[segment]; - const std::vector &synapses = segmentData.synapses; - sizes.push_back(synapses.size()); - for (Synapse synapse : synapses) { - const SynapseData &synapseData = synapses_[synapse]; - syndata.push_back(synapseData); - } - } - } ar(CEREAL_NVP(connectedThreshold_)); - //the following member must not be serialized (so is set to =0). - //That is because of we serialize only active segments & synapses, - //excluding the "destroyed", so those fields start empty. -//! ar(CEREAL_NVP(destroyedSegments_)); - ar(CEREAL_NVP(sizes)); - ar(CEREAL_NVP(syndata)); ar(CEREAL_NVP(iteration_)); + ar(CEREAL_NVP(cells_)); + ar(CEREAL_NVP(segments_)); + ar(CEREAL_NVP(synapses_)); + + ar(CEREAL_NVP(destroyedSynapses_)); + ar(CEREAL_NVP(destroyedSegments_)); + + ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_)); + ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_)); + ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_)); + ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_)); + + ar(CEREAL_NVP(nextSegmentOrdinal_)); + ar(CEREAL_NVP(nextSynapseOrdinal_)); + + ar(CEREAL_NVP(timeseries_)); + ar(CEREAL_NVP(previousUpdates_)); + ar(CEREAL_NVP(currentUpdates_)); + + ar(CEREAL_NVP(prunedSyns_)); + ar(CEREAL_NVP(prunedSegs_)); } template void load_ar(Archive & ar) { - std::deque sizes; - std::deque syndata; ar(CEREAL_NVP(connectedThreshold_)); - ar(CEREAL_NVP(sizes)); - ar(CEREAL_NVP(syndata)); - - CellIdx numCells = static_cast(sizes.front()); sizes.pop_front(); - initialize(numCells, connectedThreshold_); - for (UInt cell = 0; cell < numCells; cell++) { - size_t numSegments = sizes.front(); sizes.pop_front(); - for (SegmentIdx j = 0; j < static_cast(numSegments); j++) { - Segment segment = createSegment( cell ); - - size_t numSynapses = sizes.front(); sizes.pop_front(); - for (SynapseIdx k = 0; k < static_cast(numSynapses); k++) { - SynapseData& syn = syndata.front(); syndata.pop_front(); - createSynapse( segment, syn.presynapticCell, syn.permanence ); - } - } - } ar(CEREAL_NVP(iteration_)); + //!initialize(numCells, connectedThreshold_); //initialize Connections //Note: we actually don't call Connections + //initialize() as all the members are de/serialized. + ar(CEREAL_NVP(cells_)); + ar(CEREAL_NVP(segments_)); + ar(CEREAL_NVP(synapses_)); + + ar(CEREAL_NVP(destroyedSynapses_)); + ar(CEREAL_NVP(destroyedSegments_)); + + ar(CEREAL_NVP(potentialSynapsesForPresynapticCell_)); + ar(CEREAL_NVP(connectedSynapsesForPresynapticCell_)); + ar(CEREAL_NVP(potentialSegmentsForPresynapticCell_)); + ar(CEREAL_NVP(connectedSegmentsForPresynapticCell_)); + + ar(CEREAL_NVP(nextSegmentOrdinal_)); + ar(CEREAL_NVP(nextSynapseOrdinal_)); + + ar(CEREAL_NVP(timeseries_)); + ar(CEREAL_NVP(previousUpdates_)); + ar(CEREAL_NVP(currentUpdates_)); + + ar(CEREAL_NVP(prunedSyns_)); + ar(CEREAL_NVP(prunedSegs_)); } /** @@ -771,7 +853,7 @@ class Connections : public Serializable Synapse prunedSyns_ = 0; //how many synapses have been removed? Segment prunedSegs_ = 0; - //for listeners + //for listeners //TODO listeners are not serialized, nor included in equals == UInt32 nextEventToken_; std::map eventHandlers_; }; // end class Connections