Mastering EJB transaction control with . . . Lambda?

Posted by saintx on February 9, 2008 at 9:54 AM PST

My friend David Blevins showed me a great trick the other day. I was writing some in-container tests for an EJB/JPA application I'm working on, and needed some power tools to better control the scope of my transactions. After no small amount of pain, I came around to the idea of making the tests more realistic. That is to say, the tests needed to reflect more carefully the actual usage scenarios that end users would encounter. Key to this was the concept of units of work. The mechanism he showed me for achieving this very closely mimics lambda notation from languages like LISP.

A unit of work is, more or less, the greatest amount of "work" that can be performed in a single transaction without adversely impacting performance by tying up database connections. If your units of work are too large, you chew up connection resources. If they are too small, you waste resources on creating lots of new small connections. So, the trick is to find the "sweet spot". How much work is enough?

As my software engineering professor Mats Heimdahl would sagely advise, "It depends."

The truth is that regardless of how big or small your transactions are, it's easy to get them wrong, and you need to do them correctly each time in order to guarantee that your application and tests perform and behave in a predictable manner. So, if you can abstract this element of your software, you are likely to be much better off.

To make a long story short, David showed me a pattern that I found indispensable for testing, because it mimics the lambda capabilities of languages like LISP.

Here's a look.

I have a base testing class that contains infrastructure for my tests. It looks something like this:

package org.eremite.corm.party.testutil;

import junit.framework.TestCase;
import org.eremite.corm.Archetype;

import javax.annotation.security.RunAs;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import javax.naming.Context;
import javax.naming.NamingException;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Callable;

public abstract class CascadeTestBase extends TestCase {

    protected Context context = ContextFactory.reference();
    protected BeanManager mgr;

    public void setUp() throws Exception {
        super.setUp();
        mgr = (BeanManager) context.lookup(
            "BeanManagerLocal");
    }

    public void tearDown() throws Exception {
        super.tearDown();
    }

    protected static void traceln(Object... o) {
//        for(Object item : o) System.out.println(item);
    }

    // True if the set holds an Archetype with the given ID.
    protected boolean hasArchetypeWithID(
        Set<? extends Archetype> s, long ID) {
        return getArchetypeWithID(s, ID) != null;
    }

    // Returns the Archetype with the given ID, or null if absent.
    protected Archetype getArchetypeWithID(
        Set<? extends Archetype> s, long ID) {
        for (Archetype a : s) {
            if (a != null && a.getID() == ID) return a;
        }
        return null;
    }

    protected void clear(String... tables) {
        for (String table : tables) mgr.clear(table);
    }

    public class Retriever<A extends Archetype> {
        private String table;
        private long ID;

        public Retriever(String table, long ID) {
            this.table = table;
            this.ID = ID;
        }

        public A fetch() {
            return lookupManager().find(table, ID);
        }

        public A fetchWith(String column) {
            return lookupManager().findAll(table, column, ID).
                iterator().next();
        }

        // Look up the extended bean manager. Fail fast if the JNDI
        // name is missing, rather than printing the stack trace and
        // hitting a NullPointerException on the next line.
        @SuppressWarnings("unchecked")
        private ExtendedBeanManager<A> lookupManager() {
            try {
                return (ExtendedBeanManager<A>)
                    context.lookup(
                        "ExtendedBeanManagerLocal");
            } catch (NamingException e) {
                throw new IllegalStateException(
                    "ExtendedBeanManagerLocal not found", e);
            }
        }
    }

    // - - - - - - - - - - - - - - - - - - - - - - - - - - -
    //  Courtesy of David Blevins from the OpenEJB project
    //  These beans will get picked up if you have an empty
    //  src/test/resources/META-INF/ejb-jar.xml
    //  file which simply contains the text "<ejb-jar/>"
    // - - - - - - - - - - - - - - - - - - - - - - - - - - -

    public static interface Caller {
        public <V> V call(Callable<V> callable)
            throws Exception;
    }

    @Stateless
    @TransactionAttribute(
        TransactionAttributeType.REQUIRES_NEW)
    public static class TransactionBean implements Caller {
        public <V> V call(Callable<V> callable)
            throws Exception {
            return callable.call();
        }
    }

    @Stateless
    @RunAs("manager")
    public static class SecureBean implements Caller {
        public <V> V call(Callable<V> callable)
            throws Exception {
            return callable.call();
        }
    }

    @Stateless
    @RunAs("manager")
    @TransactionAttribute(
        TransactionAttributeType.REQUIRES_NEW)
    public static class SecureTransactedBean
        implements Caller {
        public <V> V call(Callable<V> callable)
            throws Exception {
            return callable.call();
        }
    }

    public void testTransaction() throws Exception {
        Caller transactional = (Caller)
            context.lookup("TransactionBeanLocal");
        for(Callable item : getCalls()) {
            transactional.call(item);
        }
    }

    // Test cases are required to implement this method.
    public abstract Callable[] getCalls();
}

Note, most importantly, the enterprise beans at the bottom. One of them, "TransactionBean", is the one I use in my test classes to actually perform my transactions.
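
If it seems odd that a bean whose only job is to `return callable.call()` could be useful, a container-free sketch of the same shape may help. Everything named below (`FakeTransactionCaller`, `UnitOfWorkDemo`, the log) is a hypothetical stand-in invented for illustration. In the real test the container begins and ends the transaction around `call()` because of `REQUIRES_NEW`, and none of this bookkeeping is written by hand. The rollback rule is also simplified here: a real EJB container distinguishes application exceptions from system exceptions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;

// Hypothetical stand-in for the container (not an EJB): it records where
// a transaction would begin and end around each unit of work.
class FakeTransactionCaller {
    final List<String> log = new ArrayList<String>();

    <V> V call(Callable<V> callable) throws Exception {
        log.add("begin");
        try {
            V result = callable.call();
            log.add("commit");
            return result;
        } catch (Exception e) {
            // Simplification: any exception marks the unit as rolled back.
            log.add("rollback");
            throw e;
        }
    }
}

class UnitOfWorkDemo {
    // Sends two units of work through the caller, one after the other.
    static List<String> run() throws Exception {
        final FakeTransactionCaller caller = new FakeTransactionCaller();
        caller.call(new Callable<Object>() {
            public Object call() {
                caller.log.add("work 1");
                return null;
            }
        });
        caller.call(new Callable<Object>() {
            public Object call() {
                caller.log.add("work 2");
                return null;
            }
        });
        return caller.log;
    }

    public static void main(String[] args) throws Exception {
        // prints [begin, work 1, commit, begin, work 2, commit]
        System.out.println(run());
    }
}
```

Each `Callable` gets its own begin/commit pair, which is the effect the test base class relies on when it loops over `getCalls()`.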

Now, for the sake of brevity, I'm going to refrain from describing the data model that I'm testing. For now assume it's a commercial object-relational model that closely obeys the principle of least astonishment, and that every entity extends a managed superclass "Archetype". Also, without going into the implementation details, here is the definition of the BeanManager interface used in the class above:

package org.eremite.corm.party.testutil;

import org.eremite.corm.Archetype;

import javax.ejb.Local;
import java.util.List;

@Local
public interface BeanManager<A extends Archetype> {

    public long persist(Archetype a);
    public void merge(Archetype a);
    public void remove(Archetype a);
    public long sizeOf(String table);
    public void clear(String table);
    public void flush();

    public List<A> query(String s);
    public A find(String table, long ID);
    public List<A> findAll(String table);
    public List<A> findAll(String table, String field);
    public List<A> findAll(
        String table, String field, long ID);
}

Given the two base utilities above, I can now do the following in my individual test classes in order to verify the cascading behavior for collection-valued and field-valued entity references from any given entity:

package org.eremite.corm.party;

import org.eremite.corm.Archetype;
import org.eremite.corm.party.address.Address;
import org.eremite.corm.party.address.AssociatedAddress;
import org.eremite.corm.party.testutil.CascadeTestBase;

import java.util.Set;
import java.util.concurrent.Callable;

public class AddressCascadeTest extends CascadeTestBase {

    private Address address1, address2, address3;
    private Party party1;
    private AssociatedAddress aa1;
    private Archetype party2, party3, aa2, aa3;
    private long ID, partyID, aaID;
    Retriever<Address> bean;

    /**
     * Initialize objects and persist.
     * @return Callable block
     */
    private Callable unitOfWork_01() {
        return new Callable() {
            public Object call() {
                // initialize objects
                address1 = new Address();
                party1 = new Party();
                aa1 = new AssociatedAddress(
                    party1, address1);

                // ensure that the objects are wired up
                assertTrue(
                    address1.getParties().contains(aa1));
                   
                assertTrue(
                    party1.getAddresses().contains(aa1));
                   
                assertEquals(party1, aa1.getParty());
                assertEquals(address1, aa1.getAddress());

                mgr.persist(address1);
                mgr.flush();

                // assign the IDs
                ID = address1.getID();
                partyID = party1.getID();
                aaID = aa1.getID();

                // ensure the IDs are nonzero (assertNotSame would
                // compare boxed references and always pass)
                assertTrue(ID != 0);
                assertTrue(partyID != 0);
                assertTrue(aaID != 0);

                // Done.
                return null;
            }
        };
    }

    /**
     * Verify cascading persist and make updates.
     * @return Callable block
     */
    private Callable unitOfWork_02() {
        return new Callable() {
            public Object call() {
                // Verify base conditions
                assertEquals(1, mgr.sizeOf("Address"));
                assertEquals(1, mgr.sizeOf(
                    "AssociatedAddress"));
                   
                assertEquals(1, mgr.sizeOf("Party"));

                // Get new copy of stem object from DB
                bean = new Retriever<Address>(
                    "Address", ID);
                   
                address2 = bean.fetchWith("parties");
                assertNotNull(address2);

                // Get object from collection-valued ref
                Set<AssociatedAddress> assocs =
                    address2.getParties();
                   
                assertNotNull(assocs);
                assertEquals(1, assocs.size());
                assertTrue(hasArchetypeWithID(assocs, aaID));
                aa2 = getArchetypeWithID(assocs, aaID);
                assertNotNull(aa2);

                // Make some changes.
                address2.setName("address");
                aa2.setName("associated address");

                // Done.
                return null;
            }
        };
    }

    /**
     * Verify cascading updates and delete the stem object.
     * Verify the scope of the cascading delete.
     * Remove objects left over from the cascading delete.
     * Verify the objects were removed.
     * @return Callable block
     */
    private Callable unitOfWork_03() {
        return new Callable() {
            public Object call() {
                // Verify base conditions
                assertEquals(1, mgr.sizeOf("Address"));
                assertEquals(1, mgr.sizeOf(
                    "AssociatedAddress"));
                   
                assertEquals(1, mgr.sizeOf("Party"));

                // Get new copy of stem object from db
                address3 = bean.fetchWith("parties");
                assertNotNull(address3);

                // Get object from collection-valued ref
                Set<AssociatedAddress> assocs =
                   address3.getParties();
                assertNotNull(assocs);
                assertEquals(1, assocs.size());
                assertTrue(hasArchetypeWithID(assocs, aaID));
                aa3 = getArchetypeWithID(assocs, aaID);
                assertNotNull(aa3);

                // Verify earlier changes.
                assertEquals(
                    address2.getName(),
                    address3.getName());
                   
                assertEquals(
                    aa2.getName(),
                    aa3.getName());

                // Remove stem object
                mgr.remove(address3);
                mgr.flush();

                // Verify scope of cascading deletion
                assertEquals(0, mgr.sizeOf("Address"));
                assertEquals(0, mgr.sizeOf(
                    "AssociatedAddress"));
                   
                assertEquals(1, mgr.sizeOf("Party"));

                // Clean up leftovers
                mgr.remove(new Retriever<Party>(
                    "Party",
                    partyID).fetch());
                   
                mgr.flush();

                // Verify cleanup
                assertEquals(0, mgr.sizeOf("Party"));

                // Done.
                return null;
            }
        };
    }

    /**
     * Retrieve units of work for this test case
     * @return Callable[] units of work
     */
    public Callable[] getCalls() {
        return new Callable[]{
                unitOfWork_01(),
                unitOfWork_02(),
                unitOfWork_03()
        };
    }
}

Now, the three objects that encapsulate the units of work in the test case above obey the principle behind lambda: "procedures operate in the environment in which they were created." We are sending each of the code blocks wrapped up in these anonymous inner classes to be executed by another object.

But, more importantly, we're sending their environments with them. This is the environment we assembled in this particular test case, which contains the object IDs and the object references.
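
The "environment travels with the block" idea can be shown in a few lines. This is only a sketch: `EnvironmentDemo` and its `id` field are hypothetical names standing in for the test class and fields like `ID` and `aaID`. The first block writes a field of the enclosing instance, and the second block, created separately and run later, reads it.

```java
import java.util.concurrent.Callable;

class EnvironmentDemo {
    // Shared state: written by the first block, read by the second.
    private long id;

    Callable<Object> first() {
        return new Callable<Object>() {
            public Object call() {
                id = 42L;   // updates the enclosing instance's field
                return null;
            }
        };
    }

    Callable<Long> second() {
        return new Callable<Long>() {
            public Long call() {
                return id + 1;   // sees whatever the first block left behind
            }
        };
    }

    // Creates both blocks up front, then runs them in order.
    static long run() throws Exception {
        EnvironmentDemo env = new EnvironmentDemo();
        Callable<Object> a = env.first();
        Callable<Long> b = env.second();
        a.call();
        return b.call();   // returns 43
    }
}
```

Whoever ends up calling `a` and `b` needs no knowledge of `id`. Each anonymous inner class holds a reference to the instance that created it.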

By making the CascadeTestBase class abstract, I force all of the subclasses to implement the getCalls() method, which is called by the lone test method in CascadeTestBase.
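
Stripped of JUnit and EJB, the arrangement is a small template method. The sketch below uses hypothetical names (`StepRunner`, `TwoSteps`, `runAll`) and a typed `List` in place of the raw `Callable[]` from the listing, purely to keep the example free of unchecked warnings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;

abstract class StepRunner {
    // Concrete subclasses must supply the steps; the compiler rejects
    // a non-abstract subclass that leaves this method out.
    abstract List<Callable<String>> getCalls();

    // The one driver method, written once in the base class.
    final List<String> runAll() throws Exception {
        List<String> results = new ArrayList<String>();
        for (Callable<String> step : getCalls()) {
            results.add(step.call());
        }
        return results;
    }
}

class TwoSteps extends StepRunner {
    List<Callable<String>> getCalls() {
        List<Callable<String>> steps =
            new ArrayList<Callable<String>>();
        steps.add(new Callable<String>() {
            public String call() { return "first"; }
        });
        steps.add(new Callable<String>() {
            public String call() { return "second"; }
        });
        return steps;
    }
}
```

Calling `new TwoSteps().runAll()` runs the steps in list order and returns `[first, second]`.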

By using anonymous inner classes, I am able to generate unnamed implementations of the Callable interface and pass their contents and their enclosing environments around in a manner that closely resembles lambda expressions in LISP!
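
For readers on Java 8 or later, which added lambda expressions to the language, the same unnamed `Callable` can be written far more compactly. The sketch below assumes such a compiler, and `LambdaComparison` is a hypothetical class name. It puts the two spellings side by side, both closing over the same field of the enclosing instance.

```java
import java.util.concurrent.Callable;

class LambdaComparison {
    private String name = "address";

    // The style used throughout this post: an anonymous inner class.
    Callable<String> asAnonymousClass() {
        return new Callable<String>() {
            public String call() {
                return name.toUpperCase();
            }
        };
    }

    // The Java 8+ equivalent: a lambda expression.
    Callable<String> asLambda() {
        return () -> name.toUpperCase();
    }
}
```

Both callables return `"ADDRESS"` when invoked, and either one could be handed to a `Caller` unchanged, since each is just a `Callable`.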

Next time I'll show you how I implemented ALL of the CRUD tests for my first persistence module using this pattern in fewer than 200 lines of code.