Implemented Visitor pattern

This commit is contained in:
tylerlaberge
2016-08-08 21:58:51 -04:00
parent 259ae0407a
commit e16ca94de5
2 changed files with 154 additions and 0 deletions

View File

@@ -0,0 +1,50 @@
from abc import ABCMeta, abstractmethod
class Visitor(metaclass=ABCMeta):
"""
Abstract Visitor class as part of the Visitor Design Pattern.
"""
def visit(self, node, *args, **kwargs):
"""
Visit the visitor with some object.
@param node: An object to call a visitor method with.
@param args: Arguments to go with the visitor method call.
@param kwargs: Keyword arguments to go with the visitor method call.
@return: The return value of the method that was called for visiting object.
"""
method = None
for cls in node.__class__.__mro__:
method_name = 'visit_'+cls.__name__.lower()
method = getattr(self, method_name, None)
if method:
break
if not method:
method = self.generic_visit
return method(node, *args, **kwargs)
@abstractmethod
def generic_visit(self, node, *args, **kwargs):
"""
The method to call if no methods were found for a visiting object.
@param node: An object to call a visitor method with.
@param args: Arguments to go with the visitor method call.
@param kwargs: Keyword arguments to go with the visitor method call.
"""
class Visitee(object):
"""
A base class for objects that wish to be able to visit a Visitor class.
"""
def accept(self, visitor):
"""
Visit a visitor with this class instance.
@param visitor: The visitor to visit.
@type visitor: Visitor
"""
return visitor.visit(self)

View File

@@ -0,0 +1,104 @@
from unittest import TestCase
from pypatterns.behavioral.visitor import Visitor, Visitee
class VisitorTestCase(TestCase):
"""
Unit testing class for the Visitor class.
"""
def setUp(self):
"""
Initialize testing data.
"""
class Node(object):
pass
class A(Node):
pass
class B(Node):
pass
class C(A, B):
pass
class NodeVisitor(Visitor):
def generic_visit(self, node, *args, **kwargs):
return 'generic_visit ' + node.__class__.__name__
def visit_b(self, node, *args, **kwargs):
return 'visit_b ' + node.__class__.__name__
self.a = A()
self.b = B()
self.c = C()
self.node_visitor = NodeVisitor()
def test_generic_visit(self):
"""
Test that the generic_visit method is called.
@raise AssertionError: If the test fails.
"""
self.assertEquals('generic_visit A', self.node_visitor.visit(self.a))
def test_non_generic_visit(self):
"""
Test that a non_generic visit method is called.
@raise AssertionError: If the test fails.
"""
self.assertEquals('visit_b B', self.node_visitor.visit(self.b))
def test_inheritance_visit(self):
"""
Test that a parent visit method is called if a child does not have one.
@raise AssertionError: If the test fails.
"""
self.assertEquals('visit_b C', self.node_visitor.visit(self.c))
class VisiteeTestCase(TestCase):
"""
Unit testing class for the Visitee class.
"""
def setUp(self):
"""
Initialize testing data.
"""
class Node(object):
pass
class A(Node, Visitee):
pass
class B(A, Visitee):
pass
class C(B, Visitee):
pass
class NodeVisitor(Visitor):
def generic_visit(self, node, *args, **kwargs):
return 'generic_visit ' + node.__class__.__name__
def visit_b(self, node, *args, **kwargs):
return 'visit_b ' + node.__class__.__name__
self.a = A()
self.b = B()
self.c = C()
self.visitor = NodeVisitor()
def test_accept(self):
"""
Test the accept method.
@raise AssertionError: If the test fails.
"""
self.assertEquals('generic_visit A', self.a.accept(self.visitor))
self.assertEquals('visit_b B', self.b.accept(self.visitor))
self.assertEquals('visit_b C', self.c.accept(self.visitor))