Does heapq.merge work with Iterator classes?


I'm having an issue using heapq.merge with an iterator class. Once merge is down to a single remaining iterator, it finishes with the line yield from next.__self__, which, for an iterator class like mine, restarts the iterator. The behaviour can be reproduced with this code:

import heapq

class SimpleIterator:

    a_list = [0, 1, 2, 3, 4, 5]

    def __iter__(self):
        self.list_iter = iter(self.a_list)
        return self

    def __next__(self):
        return next(self.list_iter)**2

print(list(SimpleIterator()))
print(list(heapq.merge(SimpleIterator(), SimpleIterator())))

Expected output:
[0, 1, 4, 9, 16, 25]
[0, 0, 1, 1, 4, 4, 9, 9, 16, 16, 25, 25]

Actual output:
[0, 1, 4, 9, 16, 25]
[0, 0, 1, 1, 4, 4, 9, 9, 16, 16, 25, 25, 0, 1, 4, 9, 16, 25]

Am I setting up my iterator class incorrectly?
Is there a problem with the merge function?
Should the merge function not be used with an Iterator class?

I have a workaround that wraps the class in a generator function, but I'd like to know whether there are other solutions.

Edit: the workaround:

def generator_wrapper(obj: SimpleIterator):
    yield from obj

print(list(generator_wrapper(SimpleIterator())))
print(list(heapq.merge(generator_wrapper(SimpleIterator()), generator_wrapper(SimpleIterator())))) 

Answer by chepner

TL;DR: Iterators are not allowed to start returning values again after they raise StopIteration. SimpleIterator does, because its __iter__ method replaces the underlying list iterator even after that iterator is exhausted.


heapq.merge, once it detects that all but one of the iterables have been exhausted, uses yield from next.__self__ to iterate over the remaining one. This triggers a second call to the remaining object's __iter__ method, in which you unconditionally create a new iterator over the list.
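To make the second __iter__ call visible, here is a minimal sketch. CountingIterator and tail_of_merge are hypothetical names introduced only for illustration; tail_of_merge is a simplified stand-in for the last steps of heapq.merge, not the real implementation.

```python
class CountingIterator:
    """Hypothetical variant of the question's class that counts __iter__ calls."""

    a_list = [0, 1, 2]

    def __init__(self):
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        # Unconditional reset, exactly as in the question's SimpleIterator
        self.list_iter = iter(self.a_list)
        return self

    def __next__(self):
        return next(self.list_iter) ** 2


def tail_of_merge(iterable):
    """Simplified stand-in for how merge finishes (an assumption, not the real code)."""
    it = iter(iterable)             # first __iter__ call
    bound_next = it.__next__        # only the bound method is kept around
    yield bound_next()              # one value is consumed through the bound method
    yield from bound_next.__self__  # yield from calls iter() again: second __iter__ call


source = CountingIterator()
result = list(tail_of_merge(source))
print(result)             # [0, 0, 1, 4]
print(source.iter_calls)  # 2
```

The first 0 comes from the bound method; the remaining values come from the freshly reset list iterator, which is why the sequence starts over.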

One solution would be to only create self.list_iter if it does not already exist.

class SimpleIterator:

    a_list = [0,1,2,3,4,5]
    def __init__(self):
        self.list_iter = None

    def __iter__(self):
        if self.list_iter is None:
            self.list_iter = iter(self.a_list)
        return self

    def __next__(self):
        return next(self.list_iter)**2

This ensures the existing list iterator is used to finish the merge.
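As a quick check, the guarded class can be run through heapq.merge. This is only a sketch that repeats the class above so the snippet is self-contained; the variable names merged and it are illustrative.

```python
import heapq


class SimpleIterator:

    a_list = [0, 1, 2, 3, 4, 5]

    def __init__(self):
        self.list_iter = None

    def __iter__(self):
        # Create the list iterator only once; later calls reuse it
        if self.list_iter is None:
            self.list_iter = iter(self.a_list)
        return self

    def __next__(self):
        return next(self.list_iter) ** 2


merged = list(heapq.merge(SimpleIterator(), SimpleIterator()))
print(merged)  # [0, 0, 1, 1, 4, 4, 9, 9, 16, 16, 25, 25]

# The trade-off: each instance is single-use, like any other iterator
it = SimpleIterator()
first_pass = list(it)
second_pass = list(it)
print(first_pass)   # [0, 1, 4, 9, 16, 25]
print(second_pass)  # []
```

An instance that yields nothing on a second pass is the expected behaviour for an iterator; if you need to loop more than once, create a new instance each time.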

From the documentation for __next__:

Once an iterator’s __next__() method raises StopIteration, it must continue to do so on subsequent calls. Implementations that do not obey this property are deemed broken.

By resetting self.list_iter after its __next__ has raised StopIteration the first time, your implementation of __iter__ does not satisfy this requirement.
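One rough way to test for this is sketched below. The helper stays_exhausted and the class name ResettingIterator are hypothetical; the check goes through iter() a second time because that is how for loops and yield from consume an object.

```python
def stays_exhausted(iterator):
    """Return True if the object yields nothing more once it has been drained."""
    it = iter(iterator)
    for _ in it:      # drain it completely
        pass
    try:
        next(iter(it))  # ask again, the way a second for loop or yield from would
    except StopIteration:
        return True
    return False


class ResettingIterator:
    """The question's class under a different (hypothetical) name."""

    a_list = [0, 1, 2, 3, 4, 5]

    def __iter__(self):
        self.list_iter = iter(self.a_list)
        return self

    def __next__(self):
        return next(self.list_iter) ** 2


list_result = stays_exhausted(iter([0, 1, 2]))
generator_result = stays_exhausted(x ** 2 for x in range(3))
resetting_result = stays_exhausted(ResettingIterator())

print(list_result)       # True
print(generator_result)  # True
print(resetting_result)  # False
```

Built-in iterators and generators pass the check, which is also why wrapping the class in a generator function works around the problem.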

This is only a problem because __iter__ does not return a new iterator for an instance of SimpleIterator, but returns the existing object itself. The problem also goes away if you separate the iterator from the iterable, so that every call to the iterable's __iter__ produces an independent iterator. For example:

class SimpleIterable:
    a_list = [0,1,2,3,4,5]
    
    def __iter__(self):
        return SimpleIterator(self)


class SimpleIterator:
    def __init__(self, iterable):
        self.list_iterator = iter(iterable.a_list)

    # All iterators should return self from __iter__
    def __iter__(self):
        return self

    def __next__(self):
        # Square each value, as in the question's original class
        return next(self.list_iterator)**2
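A short usage sketch of the split design with heapq.merge follows. It repeats the two classes so it runs on its own, and it assumes the iterator squares its values to match the question's expected output; the variable names are illustrative.

```python
import heapq


class SimpleIterable:
    a_list = [0, 1, 2, 3, 4, 5]

    def __iter__(self):
        # A fresh, independent iterator on every call
        return SimpleIterator(self)


class SimpleIterator:
    def __init__(self, iterable):
        self.list_iterator = iter(iterable.a_list)

    def __iter__(self):
        return self

    def __next__(self):
        # Squaring is assumed here to mirror the question's class
        return next(self.list_iterator) ** 2


source = SimpleIterable()

# The same iterable can be passed twice: merge gets two separate iterators
merged = list(heapq.merge(source, source))
print(merged)  # [0, 0, 1, 1, 4, 4, 9, 9, 16, 16, 25, 25]

# The iterable itself is reusable, unlike a drained iterator
again = list(source)
print(again)   # [0, 1, 4, 9, 16, 25]
```

With this design, the final yield from in merge calls __iter__ on a SimpleIterator, which returns itself without touching its state, so nothing restarts.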