Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions streamz/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,22 @@ def __str__(self):

class APIRegisterMixin(object):

def _new_node(self, cls, args, kwargs):
""" Constructor for downstream nodes.

Examples
--------
To provide inheritance through nodes :

>>> class MyStream(Stream):
>>>
>>> def _new_node(self, cls, args, kwargs):
>>> if not issubclass(cls, MyStream):
>>> cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
>>> return cls(*args, **kwargs)
"""
return cls(*args, **kwargs)

@classmethod
def register_api(cls, modifier=identity, attribute_name=None):
""" Add callable to Stream API
Expand Down Expand Up @@ -158,6 +174,10 @@ def register_api(cls, modifier=identity, attribute_name=None):
def _(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if identity is not staticmethod and args:
self = args[0]
if isinstance(self, APIRegisterMixin):
return self._new_node(func, args, kwargs)
return func(*args, **kwargs)
name = attribute_name if attribute_name else func.__name__
setattr(cls, name, modifier(wrapped))
Expand Down
29 changes: 29 additions & 0 deletions streamz/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,35 @@ class foo(NewStream):
assert not hasattr(Stream(), 'foo')


def test_subclass_node():

def add(x) : return x + 1

class MyStream(Stream):
def _new_node(self, cls, args, kwargs):
if not issubclass(cls, MyStream):
cls = type(cls.__name__, (cls, MyStream), dict(cls.__dict__))
return cls(*args, **kwargs)

@MyStream.register_api()
class foo(sz.sinks.sink):
pass

stream = MyStream()
lst = list()

node = stream.map(add)
assert isinstance(node, sz.core.map)
assert isinstance(node, MyStream)

node = node.foo(lst.append)
assert isinstance(node, sz.sinks.sink)
assert isinstance(node, MyStream)

stream.emit(100)
assert lst == [ 101 ]


@gen_test()
def test_latest():
source = Stream(asynchronous=True)
Expand Down