Coding District - 3 months ago 18
C++ Question

Visitor Pattern for AST

I'm trying to use the visitor pattern to perform operations on the AST of my compiler, but I can't seem to figure out an implementation that works properly.

AST classes excerpt:

class AstNode
{
public:
    AstNode() {}
};

class Program : public AstNode
{
public:
    std::vector<std::shared_ptr<Class>> classes;

    Program(const std::vector<std::shared_ptr<Class>>&);
    void accept(AstNodeVisitor& visitor) const { visitor.visit(*this); }
};

class Expression : public AstNode
{
public:
    Expression() {}
};

class Method : public Feature
{
public:
    Symbol name;
    Symbol return_type;
    std::vector<std::shared_ptr<Formal>> params;
    std::shared_ptr<Expression> body;

    Method(const Symbol&, const Symbol&, const std::vector<std::shared_ptr<Formal>>&,
           const std::shared_ptr<Expression>&);
    feature_type get_type() const;
};

class Class : public AstNode
{
public:
    Symbol name;
    Symbol parent;
    Symbol filename;
    std::vector<std::shared_ptr<Feature>> features;

    Class(const Symbol&, const Symbol&, const Symbol&,
          const std::vector<std::shared_ptr<Feature>>&);
};

class Assign : public Expression
{
public:
    Symbol name;
    std::shared_ptr<Expression> rhs;

    Assign(const Symbol&, const std::shared_ptr<Expression>&);
};


Visitor (partial implementation):

class AstNodeVisitor
{
public:
    virtual void visit(const Program&) = 0;
    virtual void visit(const Class&) = 0;
    virtual void visit(const Attribute&) = 0;
    virtual void visit(const Formal&) = 0;
    virtual void visit(const Method&) = 0;
};

class AstNodePrintVisitor : public AstNodeVisitor
{
private:
    size_t depth;

public:
    void visit(const Program& node) {
        for (auto cs : node.classes)
            visit(*cs);
    }

    void visit(const Class&);
    void visit(const Attribute&);
    void visit(const Formal&);
    void visit(const Method&);
};


How I'm using it:

AstNodePrintVisitor print;
ast_root->accept(print); // ast_root is a shared_ptr<Program>


The issue:

The Method node contains a body member of type Expression, which is a base class. How will I visit it?

I thought maybe I could simply write an accept method for each AST node and do the traversal there instead (i.e., instead of calling visit() in the visitor, call accept() on the visitable node, which then calls visit(*this), so the calls are polymorphic and the right visit() method of the visitor gets called).

However, if I do this, I lose the option of choosing between a top-down traversal (operation, then recurse) and a bottom-up traversal (recurse, then operation), since the order is fixed inside accept(). For example, a PrintVisitor needs a top-down traversal of the AST, but a TypeCheck visitor needs a bottom-up one.
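The accept()-driven traversal described above can be sketched as follows. This is a minimal illustration on a hypothetical two-node expression AST (IntLiteral and Plus are invented here, not the question's classes). Because Plus::accept() visits its operands before calling visit(*this), every visitor is forced into the same bottom-up order:

```cpp
#include <memory>
#include <string>
#include <utility>

struct IntLiteral;
struct Plus;

// Hypothetical visitor interface for a tiny expression AST.
struct Visitor {
    virtual ~Visitor() = default;
    virtual void visit(const IntLiteral&) = 0;
    virtual void visit(const Plus&) = 0;
};

struct Expr {
    virtual ~Expr() = default;
    virtual void accept(Visitor& v) const = 0;
};

struct IntLiteral : Expr {
    int value;
    explicit IntLiteral(int v) : value(v) {}
    void accept(Visitor& v) const override { v.visit(*this); }
};

struct Plus : Expr {
    std::shared_ptr<Expr> lhs, rhs;
    Plus(std::shared_ptr<Expr> l, std::shared_ptr<Expr> r)
        : lhs(std::move(l)), rhs(std::move(r)) {}
    // The traversal lives in accept(): children first, then this node.
    // No visitor can change this order.
    void accept(Visitor& v) const override {
        lhs->accept(v);
        rhs->accept(v);
        v.visit(*this);
    }
};

// Records the order in which nodes are visited.
struct OrderRecorder : Visitor {
    std::string trace;
    void visit(const IntLiteral& n) override { trace += std::to_string(n.value) + " "; }
    void visit(const Plus&) override { trace += "+ "; }
};

std::string demoFixedOrder() {
    auto tree = std::make_shared<Plus>(std::make_shared<IntLiteral>(1),
                                       std::make_shared<IntLiteral>(2));
    OrderRecorder rec;
    tree->accept(rec);
    return rec.trace;  // post-order, dictated by Plus::accept()
}
```

demoFixedOrder() yields "1 2 + ". A printer that wants "+ 1 2" would have to reorder the calls inside Plus::accept(), and that change would then apply to every visitor at once.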

Is there a way around this? Or am I over-engineering things? Right now I think the fastest way is to just implement the methods in the nodes themselves.

Answer

Let's begin with a minor correction to how the Visitor is written:

void visit(const Program& node) { 
    for (auto cs : node.classes)
        visit(*cs);
}

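The call visit(*cs) should be cs->accept(*this), to allow for virtual dispatch in the generic case. For that to work, accept() must be declared virtual in AstNode and overridden in each concrete node; the same mechanism covers the Method body: call body->accept(*this), and the dynamic type of the Expression selects the right visit() overload.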
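A minimal sketch of that dispatch, assuming accept() is a pure virtual member of AstNode overridden in every node. The classes are simplified, hypothetical stand-ins for the question's (Symbol becomes std::string, Method derives from AstNode directly instead of Feature, and IntConst is an invented leaf expression). It shows how a body held through a pointer to the Expression base class still reaches the right visit() overload:

```cpp
#include <memory>
#include <string>
#include <utility>

// Simplified stand-ins for the question's classes (hypothetical).
struct Method; struct Assign; struct IntConst;

struct AstNodeVisitor {
    virtual ~AstNodeVisitor() = default;
    virtual void visit(const Method&) = 0;
    virtual void visit(const Assign&) = 0;
    virtual void visit(const IntConst&) = 0;
};

struct AstNode {
    virtual ~AstNode() = default;
    // Each concrete node forwards to its own visit() overload.
    virtual void accept(AstNodeVisitor& v) const = 0;
};

struct Expression : AstNode {};

struct IntConst : Expression {
    int value;
    explicit IntConst(int v) : value(v) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

struct Assign : Expression {
    std::string name;
    std::shared_ptr<Expression> rhs;
    Assign(std::string n, std::shared_ptr<Expression> r)
        : name(std::move(n)), rhs(std::move(r)) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

struct Method : AstNode {
    std::string name;
    std::shared_ptr<Expression> body;  // static type is the base class
    Method(std::string n, std::shared_ptr<Expression> b)
        : name(std::move(n)), body(std::move(b)) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

// The visitor drives the recursion; accept() only selects the overload.
struct NamePrinter : AstNodeVisitor {
    std::string out;
    void visit(const Method& m) override {
        out += "method " + m.name + "; ";
        if (m.body) m.body->accept(*this);  // dispatches on the dynamic type
    }
    void visit(const Assign& a) override {
        out += "assign " + a.name + "; ";
        if (a.rhs) a.rhs->accept(*this);
    }
    void visit(const IntConst& c) override {
        out += "const " + std::to_string(c.value) + "; ";
    }
};

std::string demoBodyDispatch() {
    Method m("main", std::make_shared<Assign>("x", std::make_shared<IntConst>(42)));
    NamePrinter p;
    m.accept(p);
    return p.out;
}
```

demoBodyDispatch() returns "method main; assign x; const 42; ": the visitor decides when to descend, and body->accept(*this) resolves to visit(const Assign&) even though body is declared as a pointer to Expression.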


Now to the main question: controlling the traversal order.

A visitor can really only traverse a tree depth-first. Breadth-first traversal can be implemented, but it is awkward within a single visit method (you basically need to separate visitation from iteration over the children).
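One way to do that separation, sketched under stated assumptions: nodes expose a hypothetical children() accessor, a hypothetical free function visitBreadthFirst() owns a queue and calls accept() on each node, and the visitor's visit() methods never recurse. Leaf and Branch are invented node types for illustration:

```cpp
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

struct Leaf; struct Branch;

// Hypothetical visitor: each visit() handles one node and never recurses.
struct Visitor {
    virtual ~Visitor() = default;
    virtual void visit(const Leaf&) = 0;
    virtual void visit(const Branch&) = 0;
};

struct AstNode {
    virtual ~AstNode() = default;
    virtual void accept(Visitor& v) const = 0;
    // Iteration over children is exposed separately from visitation.
    virtual std::vector<const AstNode*> children() const { return {}; }
};

struct Leaf : AstNode {
    std::string name;
    explicit Leaf(std::string n) : name(std::move(n)) {}
    void accept(Visitor& v) const override { v.visit(*this); }
};

struct Branch : AstNode {
    std::string name;
    std::vector<std::shared_ptr<AstNode>> kids;
    Branch(std::string n, std::vector<std::shared_ptr<AstNode>> k)
        : name(std::move(n)), kids(std::move(k)) {}
    void accept(Visitor& v) const override { v.visit(*this); }
    std::vector<const AstNode*> children() const override {
        std::vector<const AstNode*> out;
        for (const auto& k : kids) out.push_back(k.get());
        return out;
    }
};

// The driver, not the visitor, owns the breadth-first iteration order.
void visitBreadthFirst(const AstNode& root, Visitor& v) {
    std::queue<const AstNode*> pending;
    pending.push(&root);
    while (!pending.empty()) {
        const AstNode* n = pending.front();
        pending.pop();
        n->accept(v);
        for (const AstNode* c : n->children()) pending.push(c);
    }
}

struct NameCollector : Visitor {
    std::string trace;
    void visit(const Leaf& l) override { trace += l.name + " "; }
    void visit(const Branch& b) override { trace += b.name + " "; }
};

std::string demoBreadthFirst() {
    // Tree: A has children B and c; B has child d.
    auto b = std::make_shared<Branch>("B", std::vector<std::shared_ptr<AstNode>>{
                                               std::make_shared<Leaf>("d")});
    Branch a("A", {b, std::make_shared<Leaf>("c")});
    NameCollector names;
    visitBreadthFirst(a, names);
    return names.trace;
}
```

For this tree, demoBreadthFirst() yields "A B c d ", whereas a recursive pre-order visitor would produce "A B d c ".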

On the other hand, even in a depth-first traversal, you may choose whether to act on the parent before or after visiting its children.

The typical way to do so is to provide an intermediate layer between the pure abstract base class and the concrete visitor, for example:

class RecursiveAstNodeVisitor: public AstNodeVisitor 
{
public:
    // actBefore() returns true to stop the recursion (the children and
    // actAfter() are then skipped), false to continue.
    virtual bool actBefore(Program const&) { return false; }
    virtual void actAfter(Program const&) {}

    virtual bool actBefore(Class const&) { return false; }
    virtual void actAfter(Class const&) {}

    // ... You get the idea


    void visit(Program const& p) override {
        if (actBefore(p)) { return; }

        for (auto const& c: p.classes) {
            c->accept(*this);   // virtual dispatch to the right visit()
        }

        actAfter(p);
    }

    // ... You get the idea
};

The overrider is free to act either before or after the recursion occurs... and of course may do both!

#include <ostream>
#include <string>

class PrintAstNodeVisitor: public RecursiveAstNodeVisitor {
public:
    explicit PrintAstNodeVisitor(std::ostream& out): _out(out), _prefix() {}

    bool actBefore(Program const&) override {
        _out << "{\n";
        _out << "  \"type\": \"Program\",\n";
        _out << "  \"classes\": [\n";

        _prefix = "    ";   // indentation for the nested class entries

        return false;       // keep recursing into the classes
    }

    void actAfter(Program const&) override {
        _out << "  ]\n";
        _out << "}\n";
    }

    bool actBefore(Class const& c) override {
        _out << _prefix << "{\n";
        _out << _prefix << "  \"type\": \"Class\",\n";
        // ...
        return false;
    }

private:
    std::ostream& _out;
    std::string _prefix;
};
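To tie this back to the PrintVisitor/TypeCheck example from the question, here is a self-contained sketch of the same intermediate-layer design on a small hypothetical expression AST (IntConst, BoolConst and Plus are invented for illustration, not the question's classes). The printer acts in actBefore() (top-down); the type checker acts in actAfter() (bottom-up), keeping operand types on a stack:

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical expression AST, not the question's classes.
struct IntConst; struct BoolConst; struct Plus;

struct AstNodeVisitor {
    virtual ~AstNodeVisitor() = default;
    virtual void visit(const IntConst&) = 0;
    virtual void visit(const BoolConst&) = 0;
    virtual void visit(const Plus&) = 0;
};

struct Expression {
    virtual ~Expression() = default;
    virtual void accept(AstNodeVisitor& v) const = 0;
};

struct IntConst : Expression {
    int value;
    explicit IntConst(int v) : value(v) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

struct BoolConst : Expression {
    bool value;
    explicit BoolConst(bool b) : value(b) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

struct Plus : Expression {
    std::shared_ptr<Expression> lhs, rhs;
    Plus(std::shared_ptr<Expression> l, std::shared_ptr<Expression> r)
        : lhs(std::move(l)), rhs(std::move(r)) {}
    void accept(AstNodeVisitor& v) const override { v.visit(*this); }
};

// Intermediate layer: owns the depth-first recursion and exposes hooks.
// actBefore() returns true to skip the children and actAfter().
struct RecursiveAstNodeVisitor : AstNodeVisitor {
    virtual bool actBefore(const IntConst&) { return false; }
    virtual void actAfter(const IntConst&) {}
    virtual bool actBefore(const BoolConst&) { return false; }
    virtual void actAfter(const BoolConst&) {}
    virtual bool actBefore(const Plus&) { return false; }
    virtual void actAfter(const Plus&) {}

    void visit(const IntConst& n) override { if (!actBefore(n)) actAfter(n); }
    void visit(const BoolConst& n) override { if (!actBefore(n)) actAfter(n); }
    void visit(const Plus& n) override {
        if (actBefore(n)) return;
        n.lhs->accept(*this);
        n.rhs->accept(*this);
        actAfter(n);
    }
};

// Top-down: each node is printed before its operands (prefix notation).
struct PrintVisitor : RecursiveAstNodeVisitor {
    using RecursiveAstNodeVisitor::actBefore;
    std::string out;
    bool actBefore(const IntConst& n) override { out += std::to_string(n.value) + " "; return false; }
    bool actBefore(const BoolConst& n) override { out += (n.value ? "true " : "false "); return false; }
    bool actBefore(const Plus&) override { out += "+ "; return false; }
};

// Bottom-up: a node's type is computed only after its operands' types.
struct TypeCheckVisitor : RecursiveAstNodeVisitor {
    using RecursiveAstNodeVisitor::actAfter;
    std::vector<std::string> types;  // stack of already-checked operand types
    void actAfter(const IntConst&) override { types.push_back("Int"); }
    void actAfter(const BoolConst&) override { types.push_back("Bool"); }
    void actAfter(const Plus&) override {
        std::string r = types.back(); types.pop_back();
        std::string l = types.back(); types.pop_back();
        types.push_back(l == "Int" && r == "Int" ? "Int" : "Error");
    }
};

std::string printPrefix(const Expression& e) { PrintVisitor p; e.accept(p); return p.out; }
std::string typeOf(const Expression& e) { TypeCheckVisitor t; e.accept(t); return t.types.back(); }
```

For Plus(Plus(1, 2), true), printPrefix() gives "+ + 1 2 true " and typeOf() gives "Error", while Plus(1, 2) checks as "Int". The using-declarations keep the inherited actBefore/actAfter overloads visible inside each derived class, which would otherwise hide them from lookup (GCC and Clang can warn about this with -Woverloaded-virtual).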