# You are a data scientist working on flood mitigation in a mountainous province.
# Streams flow down from the mountain onto a high desert plain, where they combine
# and eventually merge into one large river. On some tributary streams, there is a
# water gauge that measures the flow rate in cubic feet per second (cfs).
#
# In the desert plain, there are a number of small towns along the rivers that
# want to predict how much water will flow through their riverbeds after a storm.
#
# Your task is to implement two methods to enable this use case.
# 1: `update(flow_datum)` - called whenever a gauge logs a new reading
# 2: `expected_runoff(town, timestamp)` - computes the total expected flow `town`
#    expects to see at `timestamp`, given the gauge readings and the average time
#    it takes the water to flow through each river segment
#
# As input, you will receive a tree representing the river system and the towns on each river.
# For your convenience, we've sketched out the rivers and towns in the sample data below.
#
# Clarifications/Assumptions:
# 1: if `update` is called with a river_id `r` and timestamp `ts`, all future calls to `update`
#    with river_id `r` will have timestamp `ts' > ts`
# 2: assume that the total flow of water is conserved through the system - ie no water gets lost
#    or added to the system after it passes through a gauge (no net evaporation/rainfall/seepage)
# 3: you can assume that there will be many more calls to `update` than to `expected_runoff` when
#    reasoning about runtime tradeoffs between the two methods

from bisect import bisect_right, insort
from dataclasses import dataclass, field
from datetime import datetime, timedelta

@dataclass
class FlowDatum:
    """One gauge reading: ``flow`` cfs observed on river ``river_id`` at ``ts``."""

    river_id: int  # id of the RiverNode whose gauge logged this reading
    ts: datetime  # when the reading was taken
    flow: int  # measured flow rate in cubic feet per second (cfs)

@dataclass
class RiverNode:
    """One river segment in the river tree, together with the segments feeding it.

    Instances may be constructed positionally, so field order is significant.
    """

    id: int
    # Tributary segments whose water flows into this one.
    tributaries: "list[RiverNode]"
    # Average time for water to travel from the start of this segment to its end
    # (from the gauge to the end when ``has_flow_gauge`` is set). The name is
    # presumably a typo for "flow_time"; it is kept because callers use it.
    float_time: timedelta
    # Town name -> time for water to reach that town from the start of this segment.
    towns: dict[str, timedelta]
    # True when this segment carries a flow gauge.
    has_flow_gauge: bool = field(default=False)

class MountainRunoff:
    """Estimate the flow reaching downstream towns from upstream gauge readings.

    ``update`` records gauge readings. It is expected to be called far more
    often than ``expected_runoff``, so it stays an O(1) append in the common
    case. ``expected_runoff`` walks upstream from a town's river segment,
    shifting the query time back by each segment's travel time, and sums the
    latest applicable gauge readings.
    """

    def __init__(self, river_system: RiverNode):
        """Index the river tree rooted at ``river_system`` by town name.

        If the same town name appears on several segments, the one visited
        last in a depth-first walk (tributaries in list order) wins.
        """
        self._rivers = river_system
        # river_id -> that gauge's readings, kept sorted by timestamp so a
        # query can binary-search instead of scanning every reading.
        self._samples: dict[int, list[FlowDatum]] = {}
        self._town_to_river: dict[str, RiverNode] = {}

        def populate_town_to_river(river: RiverNode):
            """Populate _town_to_river recursively."""
            self._town_to_river.update({town: river for town in river.towns})
            for tributary in river.tributaries:
                populate_town_to_river(tributary)

        populate_town_to_river(river_system)

    def update(self, flow_datum: FlowDatum):
        """Record a new gauge reading.

        Readings for a river are expected in increasing timestamp order,
        which makes this an O(1) append. A reading that arrives out of order
        is inserted at its sorted position (O(n)). Queries therefore always
        see the latest reading by timestamp rather than by arrival order.
        """
        samples = self._samples.setdefault(flow_datum.river_id, [])
        if not samples or samples[-1].ts <= flow_datum.ts:
            samples.append(flow_datum)
        else:
            insort(samples, flow_datum, key=lambda datum: datum.ts)

    def expected_runoff(self, town: str, timestamp: datetime) -> int | None:
        """Return the total flow (cfs) expected at ``town`` at ``timestamp``.

        Returns None if the town is unknown. Also returns None if some
        upstream branch has no gauge reading at or before the time its water
        would have had to pass that gauge to reach the town by ``timestamp``.
        """
        river = self._town_to_river.get(town)
        if river is None:
            return None
        # Water reaching the town at `timestamp` entered this segment
        # `towns[town]` earlier. The index was built from `river.towns`,
        # so the key is always present.
        return self._flow_at(river, timestamp - river.towns[town])

    def _latest_reading(self, river_id: int, ts: datetime) -> int | None:
        """Return the flow of the last reading for ``river_id`` taken at or before ``ts``."""
        samples = self._samples.get(river_id)
        if not samples:
            return None
        # bisect's `key` is applied to list elements only, not to `ts`.
        idx = bisect_right(samples, ts, key=lambda datum: datum.ts)
        return samples[idx - 1].flow if idx else None

    def _flow_at(self, river: RiverNode, ts: datetime) -> int | None:
        """Return the flow at the start of ``river`` (at its gauge, if gauged) at ``ts``.

        A gauged segment answers from its own readings when one exists at or
        before ``ts``. Otherwise the flow is the sum of each tributary's flow
        one tributary ``float_time`` earlier. An unknown tributary, or no
        tributaries at all, makes the result None.
        """
        if river.has_flow_gauge:
            reading = self._latest_reading(river.id, ts)
            if reading is not None:
                return reading
            # No reading yet: fall through to aggregating tributaries.
            # NOTE(review): for a gauged segment `ts` is the time at the gauge,
            # but RiverNode does not record the start-to-gauge delay, so this
            # fallback treats the gauge as if it sat at the segment start.

        if not river.tributaries:
            return None

        total_flow = 0
        for tributary in river.tributaries:
            tributary_flow = self._flow_at(tributary, ts - tributary.float_time)
            if tributary_flow is None:
                return None
            total_flow += tributary_flow
        return total_flow

# ------------- SAMPLE DATA ------------- #

# Gauged segments are leaves; towns sit on the ungauged segments that merge them.
# `has_flow_gauge` is annotated `bool`, so pass `True` rather than the int `1`.
river_system = RiverNode(
    id=0,
    towns={"abilene": timedelta(hours=5)},
    float_time=timedelta(hours=10),
    tributaries=[
        RiverNode(
            id=1,
            towns={"bigfoot": timedelta(hours=4)},
            float_time=timedelta(hours=6),
            tributaries=[
                RiverNode(id=3, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=5)),
                RiverNode(id=4, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=6)),
                RiverNode(id=5, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=5)),
            ]
        ),
        RiverNode(
            id=2,
            towns={"cadbury": timedelta(hours=5), "davenport": timedelta(hours=2)},
            float_time=timedelta(hours=6),
            tributaries=[
                RiverNode(id=6, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=5)),
                RiverNode(
                    id=7,
                    towns={"eagle": timedelta(hours=2)},
                    float_time=timedelta(hours=3),
                    tributaries=[
                        RiverNode(id=8, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=2)),
                        RiverNode(id=9, towns={}, tributaries=[], has_flow_gauge=True, float_time=timedelta(hours=2)),
                    ]
                )
            ]
        )
    ]
)

# Each row is (river_id, day of February 2026, hour, flow in cfs). For any one
# river the rows are in increasing time order; rows for different rivers interleave.
flow_data = [
    FlowDatum(river_id, datetime(year=2026, month=2, day=day, hour=hour), flow)
    for river_id, day, hour, flow in (
        (3, 1, 5, 5000),
        (4, 1, 5, 4000),
        (5, 1, 6, 3000),
        (6, 1, 6, 5000),
        (8, 1, 6, 6000),
        (9, 1, 7, 2000),
        (3, 2, 5, 4000),
        (4, 2, 7, 3000),
        (5, 2, 11, 1000),
        (8, 2, 12, 3000),
        (9, 2, 12, 3000),
        (6, 2, 14, 500),
        (3, 3, 2, 7000),
        (5, 3, 5, 6000),
        (6, 3, 7, 4000),
        (4, 3, 6, 8000),
        (8, 3, 8, 9000),
        (9, 3, 10, 3000),
        (3, 4, 5, 2000),
        (4, 4, 5, 500),
        (5, 4, 6, 1000),
        (6, 4, 8, 1500),
        (8, 4, 6, 2000),
        (9, 4, 7, 1000),
    )
]

# ------------- EXAMPLE ------------- #

# assert mountain_runoff.expected_runoff("eagle", datetime(year=2026, month=2, day=4, hour=10, minute=30)) == 5000
#
# The town of Eagle is downstream of gauges on rivers 8 and 9. It takes 4 hours for water to flow from the gauges to
# Eagle (2 hours along river 8 or 9, then 2 more along river 7). Since Eagle wants to know its flow at day 4, hour 10.5,
# we need to look at the gauges' flows as of day 4, hour 6.5.
#
# Gauge 8's latest entry is at day 4, hour 6, reading 2000 cfs. Gauge 9's latest entry is at day 3, hour 10 (note we cannot
# use the day 4 reading from gauge 9 because it is in the future of the time we want to sample), reading 3000 cfs. Therefore,
# the town of Eagle estimates it'll see a flow of 5000 cfs on 2/4/26 at 10:30 am

# ------------- BASIC TESTS ------------- #

# Here, we test your implementation when there is only one data point per gauge:
# Gauge 3 -> 5000
# Gauge 4 -> 4000
# Gauge 5 -> 3000
# Gauge 6 -> 5000
# Gauge 8 -> 6000
# Gauge 9 -> 2000

mountain_runoff = MountainRunoff(river_system)
for flow_datum in flow_data[:6]:  # the first six rows hold exactly one reading per gauge
    mountain_runoff.update(flow_datum)

# (town, query time, expected cfs) -- the trailing comment names the gauges that contribute.
for town, when, expected in (
    ("eagle", datetime(year=2026, month=2, day=4, hour=10, minute=30), 8000),  # 8 + 9
    ("cadbury", datetime(year=2026, month=2, day=4, hour=9, minute=30), 13000),  # 6 + 8 + 9
    ("bigfoot", datetime(year=2026, month=2, day=4, hour=19, minute=30), 12000),  # 3 + 4 + 5
    ("abilene", datetime(year=2026, month=2, day=4, hour=21, minute=30), 25000),  # all
):
    assert mountain_runoff.expected_runoff(town, when) == expected

# ------------- FULL TESTS ------------- #

mountain_runoff = MountainRunoff(river_system)

for flow_datum in flow_data:
    mountain_runoff.update(flow_datum)

assert mountain_runoff.expected_runoff("ficticious", datetime(year=2026, month=2, day=4, hour=6, minute=30)) is None
assert mountain_runoff.expected_runoff("eagle", datetime(year=2026, month=2, day=4, hour=10, minute=30)) == 5000
assert mountain_runoff.expected_runoff("davenport", datetime(year=2026, month=2, day=1, hour=13, minute=30)) is None
assert mountain_runoff.expected_runoff("cadbury", datetime(year=2026, month=2, day=2, hour=9, minute=30)) == 13000
assert mountain_runoff.expected_runoff("bigfoot", datetime(year=2026, month=2, day=2, hour=19, minute=30)) == 10000
assert mountain_runoff.expected_runoff("abilene", datetime(year=2026, month=2, day=3, hour=21, minute=30)) == 22500
assert mountain_runoff.expected_runoff("abilene", datetime(year=2026, month=2, day=3, hour=20, minute=30)) == 17500

# Davenport is 3 hours upstream of Cadbury (2h vs. 5h from the start of river 2), so their expected
# flows must match at any time when offset by 3 hours.
non_trivial_comparisons = 0
for i in range(10):
    cadbury_time = datetime(year=2026, month=2, day=1, hour=10, minute=30) + (i * timedelta(hours=8))
    cadbury_flow = mountain_runoff.expected_runoff("cadbury", cadbury_time)
    assert cadbury_flow == mountain_runoff.expected_runoff("davenport", cadbury_time - timedelta(hours=3))
    if cadbury_flow is not None:
        non_trivial_comparisons += 1
# `None == None` also satisfies the check above, so make sure it was not vacuous.
assert non_trivial_comparisons > 0

