You are given two integers, m and k, and a stream of integers. You are tasked to implement a data structure that calculates the MKAverage for the stream.
The MKAverage can be calculated using these steps:
- If the number of elements in the stream is less than m, you should consider the MKAverage to be -1. Otherwise, copy the last m elements of the stream to a separate container.
- Remove the smallest k elements and the largest k elements from the container.
- Calculate the average value for the rest of the elements, rounded down to the nearest integer.
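To make the definition concrete, here is a minimal sketch that applies the three steps directly to a list. The function name mk_average and the sample stream are illustrative assumptions, not part of the problem statement.

```python
def mk_average(stream, m, k):
    """Compute the MKAverage of `stream` straight from the definition."""
    if len(stream) < m:
        return -1
    container = sorted(stream[-m:])    # copy the last m elements, then sort
    middle = container[k:m - k]        # drop the k smallest and k largest
    return sum(middle) // len(middle)  # average, rounded down


print(mk_average([3, 1], 3, 1))               # fewer than m elements: -1
print(mk_average([3, 1, 10], 3, 1))           # middle is [3]: 3
print(mk_average([3, 1, 10, 5, 5, 5], 3, 1))  # last three are [5, 5, 5]: 5
```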
Implement the MKAverage class:
- MKAverage(int m, int k) Initializes the MKAverage object with an empty stream and the two integers m and k.
- void addElement(int num) Inserts a new element num into the stream.
- int calculateMKAverage() Calculates and returns the MKAverage for the current stream rounded down to the nearest integer.
solution
brute force idea
Initially, the first idea that comes to mind for a brute force solution is to maintain a window of length m that stores the latest m elements. Then, for each query, we sort the window, remove the k leftmost and k rightmost elements, and calculate the average of the rest.
The runtime of this approach is O(1) for addElement (when using a deque for the list of length m). The runtime for calculateMKAverage is O(m log m), where we first sort and then calculate the sum of the middle elements.
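The brute force idea could look roughly like the sketch below. The class name BruteForceMKAverage is made up for illustration, and using deque(maxlen=m) to drop the oldest element is one possible choice among several.

```python
from collections import deque


class BruteForceMKAverage:
    def __init__(self, m, k):
        self.m = m
        self.k = k
        # With maxlen set, appending to a full deque discards the oldest item.
        self.window = deque(maxlen=m)

    def addElement(self, num):
        self.window.append(num)  # O(1)

    def calculateMKAverage(self):
        if len(self.window) < self.m:
            return -1
        # O(m log m): sort the window, then keep only the middle m - 2k values.
        middle = sorted(self.window)[self.k:self.m - self.k]
        return sum(middle) // len(middle)


mk = BruteForceMKAverage(3, 1)
for num in (3, 1, 10):
    mk.addElement(num)
print(mk.calculateMKAverage())  # window is [3, 1, 10], middle is [3]
```

This version is handy as a reference implementation when testing the faster variants.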
optimization idea #1
However, if we want to improve on this, we probably want a way to maintain a sorted representation of our data, instead of having to sort every time we call calculateMKAverage.
To do this, we can maintain an unsorted list of elements called window, as well as a sorted list of elements called s. Every time we call addElement, we update our sliding window window in O(1) time, and then update s in O(m) time (O(log m) to find the element to remove, and then O(m) to remove the element).
Then, our queries will just require us to sum up the middle values in s, taking O(m) time.
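A sketch of this idea using the standard bisect module on a plain Python list is shown below. The class name SortedWindowMKAverage is hypothetical; window and s follow the names used above.

```python
from bisect import bisect_left, insort
from collections import deque


class SortedWindowMKAverage:
    def __init__(self, m, k):
        self.m = m
        self.k = k
        self.window = deque()  # elements in arrival order
        self.s = []            # the same elements, kept sorted

    def addElement(self, num):
        self.window.append(num)
        insort(self.s, num)                 # binary search, then O(m) shift
        if len(self.window) > self.m:
            old = self.window.popleft()     # O(1)
            idx = bisect_left(self.s, old)  # O(log m) to locate the element
            self.s.pop(idx)                 # O(m) to delete it

    def calculateMKAverage(self):
        if len(self.window) < self.m:
            return -1
        middle = self.s[self.k:self.m - self.k]  # no sorting needed here
        return sum(middle) // len(middle)
```

The query no longer sorts, but it still sums up to m values, and every update still pays a linear cost for shifting list elements.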
optimization idea #2
Instead of maintaining one large sorted array s of length m, we can maintain three sorted arrays, called lows, mids, and highs. We store the lowest k elements, the middle m - 2k elements, and the highest k elements respectively in these sorted lists.
Then, when we call addElement, we remove the oldest element (falling out of the window) and add the newest element. With plain Python lists, these steps take up to O(m) time, depending on which lists we're adding to or removing from.
After removing a value, we shift values leftward to maintain our invariant that the three lists stay sorted relative to each other. We guarantee that after _remove runs, there is exactly one open slot in self.highs.
Then, we run _add, inserting the number into the correct list, and shifting numbers rightward if necessary.
Finally, we can still run queries in O(m) time.
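As a small worked trace of the shifting described above, the sketch below uses plain lists and bisect.insort with made-up values (m = 7, k = 2). It removes one value from lows and then adds a value that belongs in mids.

```python
from bisect import insort

lows, mids, highs = [1, 2], [4, 6, 7], [8, 9]  # m = 7, k = 2

# Remove 2 (it lives in lows), then shift one value left from each neighbour.
lows.remove(2)
insort(lows, mids.pop(0))    # 4 moves from mids to lows
insort(mids, highs.pop(0))   # 8 moves from highs to mids
print(lows, mids, highs)     # [1, 4] [6, 7, 8] [9]  (one open slot in highs)

# Add 5: it belongs in mids, so the largest mid is pushed right into highs.
insort(mids, 5)
insort(highs, mids.pop())    # 8 moves back to highs
print(lows, mids, highs)     # [1, 4] [5, 6, 7] [8, 9]
```

After both steps every list is back to its full size, and every value in lows is at most every value in mids, which in turn is at most every value in highs.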
a note on the use of SortedList
Instead of using normal Python lists, we use SortedList here, as the implementation of SortedList achieves roughly O(log n) amortized runtime on insertion/deletion, since it represents its contents internally as a list of smaller lists. This improves the runtime of addElement to O(log m). For details, see the sortedcontainers documentation.
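To give some intuition for the "list of smaller lists" layout, here is a toy sketch. It is not the actual sortedcontainers implementation; the class ToyBucketList, its load parameter, and the splitting rule are all simplifying assumptions made for illustration.

```python
from bisect import bisect_left, insort


class ToyBucketList:
    """Toy sorted container: a list of small sorted lists (not the real SortedList)."""

    def __init__(self, load=4):
        self.load = load    # hypothetical cap on bucket size
        self.buckets = []   # each bucket is sorted, and buckets are in order

    def _bucket_index(self, value):
        # First bucket whose largest element is >= value, else the last bucket.
        # A real implementation would cache these maxima instead of rebuilding.
        maxes = [bucket[-1] for bucket in self.buckets]
        return min(bisect_left(maxes, value), len(self.buckets) - 1)

    def add(self, value):
        if not self.buckets:
            self.buckets.append([value])
            return
        i = self._bucket_index(value)
        bucket = self.buckets[i]
        insort(bucket, value)          # shifts only within one small bucket
        if len(bucket) > self.load:    # split an over-full bucket in two
            half = len(bucket) // 2
            self.buckets[i:i + 1] = [bucket[:half], bucket[half:]]

    def remove(self, value):
        if not self.buckets:
            raise ValueError(f"{value!r} not in list")
        i = self._bucket_index(value)
        bucket = self.buckets[i]
        j = bisect_left(bucket, value)
        if j == len(bucket) or bucket[j] != value:
            raise ValueError(f"{value!r} not in list")
        bucket.pop(j)
        if not bucket:                 # drop buckets that become empty
            self.buckets.pop(i)

    def __iter__(self):
        for bucket in self.buckets:
            yield from bucket


toy = ToyBucketList(load=2)
for x in [5, 1, 4, 2, 3]:
    toy.add(x)
toy.remove(4)
print(list(toy))  # [1, 2, 3, 5]
```

Because each insertion or deletion only shifts elements inside one small bucket, updates avoid the full O(n) shift of a single flat list.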
from collections import deque

from sortedcontainers import SortedList


class MKAverage:
    def __init__(self, m: int, k: int):
        self.window = deque()
        self.lows = SortedList()
        self.mids = SortedList()
        self.highs = SortedList()
        self.m = m
        self.k = k

    def addElement(self, num: int) -> None:
        # this is only run once, when the window first reaches size m
        if len(self.window) + 1 == self.m:
            self.window.append(num)
            self._initializeLists()
        # still building the window
        elif len(self.window) < self.m:
            self.window.append(num)
        else:
            # update sliding window
            to_remove = self.window.popleft()
            self.window.append(num)
            self._remove(to_remove)
            self._add(num)

    def calculateMKAverage(self) -> int:
        if len(self.window) < self.m:
            return -1
        return sum(self.mids) // (self.m - 2 * self.k)

    def _initializeLists(self):
        # split the sorted window into the k lowest, the middle, and the k highest
        for i, num in enumerate(sorted(self.window)):
            if i < self.k:
                self.lows.add(num)
            elif i >= self.m - self.k:
                self.highs.add(num)
            else:
                self.mids.add(num)

    def _add(self, num):
        # assume groups always begin with one value
        # missing in highs when this function is called
        if num >= self.mids[-1]:
            self.highs.add(num)
        elif num >= self.lows[-1]:
            self.mids.add(num)
            self.highs.add(self.mids.pop(-1))
        else:
            self.lows.add(num)
            self.mids.add(self.lows.pop(-1))
            self.highs.add(self.mids.pop(-1))

    def _remove(self, num):
        """Remove num, then shift values left to fill"""
        # assume that groups are full to begin with
        if num <= self.lows[-1]:
            self.lows.remove(num)  # O(log n)
            # shift left
            self.lows.add(self.mids.pop(0))
            self.mids.add(self.highs.pop(0))
        elif num >= self.highs[0]:
            self.highs.remove(num)  # O(log n)
        else:
            self.mids.remove(num)  # O(log n)
            self.mids.add(self.highs.pop(0))

optimization idea #3
Instead of summing up the entirety of mids at query time, we can maintain a counter called sum that keeps track of the sum of mids. This code is slightly harder to read, but improves query time from O(m) to O(1).
from collections import deque

from sortedcontainers import SortedList


class MKAverage:
    def __init__(self, m: int, k: int):
        self.window = deque()
        self.sum = 0  # running sum of the values in self.mids
        self.lows = SortedList()
        self.mids = SortedList()
        self.highs = SortedList()
        self.m = m
        self.k = k

    def addElement(self, num: int) -> None:
        # this is only run once, just finished window
        if len(self.window) + 1 == self.m:
            self.window.append(num)
            self._initializeLists()
        # still building window
        elif len(self.window) < self.m:
            self.window.append(num)
        else:
            # update sliding window and window sum
            to_remove = self.window.popleft()
            self.window.append(num)
            self._remove(to_remove)
            self._add(num)

    def calculateMKAverage(self) -> int:
        if len(self.window) < self.m:
            return -1
        return self.sum // (self.m - 2 * self.k)

    def _initializeLists(self):
        for i, num in enumerate(sorted(self.window)):
            if i < self.k:
                self.lows.add(num)
            elif i >= self.m - self.k:
                self.highs.add(num)
            else:
                self.mids.add(num)
        self.sum = sum(self.mids)

    def _add(self, num):
        # assume groups always begin with one value
        # missing in highs when this function is called
        if num >= self.mids[-1]:
            self.highs.add(num)
        elif num >= self.lows[-1]:
            self.mids.add(num)
            self.sum += num
            removed_mid = self.mids.pop(-1)
            self.sum -= removed_mid
            self.highs.add(removed_mid)
        else:
            self.lows.add(num)
            removed_low = self.lows.pop(-1)
            self.mids.add(removed_low)
            self.sum += removed_low
            removed_mid = self.mids.pop(-1)
            self.sum -= removed_mid
            self.highs.add(removed_mid)

    def _remove(self, num):
        """Remove num, then shift values left to fill"""
        # assume that groups are full to begin with
        if num <= self.lows[-1]:
            self.lows.remove(num)  # O(log n)
            # shift left
            removed_mid = self.mids.pop(0)
            self.lows.add(removed_mid)
            self.sum -= removed_mid
            removed_high = self.highs.pop(0)
            self.mids.add(removed_high)
            self.sum += removed_high
        elif num >= self.highs[0]:
            self.highs.remove(num)  # O(log n)
        else:
            self.mids.remove(num)  # O(log n)
            self.sum -= num
            removed_high = self.highs.pop(0)
            self.mids.add(removed_high)
            self.sum += removed_high