Programming for fun and profit

Programming tutorials, problems, solutions. Always with code.

JUnit Custom Rules

JUnit comes with a bunch of rules, but sometimes it is useful to write your own custom rules. In the following tutorial we’re going to show two ways to do that.

Basically there are two ways to write custom rules in JUnit:

  • extend the ExternalResource class
    Overriding the before() and/or after() methods is enough in most cases.
  • implement TestRule
    Use it when you need to evaluate the Statement in your own way.
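As a minimal sketch of the second approach, assuming JUnit 4.12 or later on the classpath, here is a hypothetical RetryRule (not part of JUnit) that re-runs a failing test up to a given number of times. Deciding whether and how often to call base.evaluate() is exactly what the before()/after() hooks cannot do, which is why it implements TestRule directly.

```java
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

// Hypothetical rule (not part of JUnit): re-runs a failing test up to maxAttempts times.
// It implements TestRule because it controls how often the Statement is evaluated.
class RetryRule implements TestRule {

    private final int maxAttempts;

    RetryRule(int maxAttempts) {
        if (maxAttempts < 1) {
            throw new IllegalArgumentException("maxAttempts must be >= 1");
        }
        this.maxAttempts = maxAttempts;
    }

    @Override
    public Statement apply(final Statement base, final Description description) {
        // Return a new Statement that wraps the original one.
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                Throwable lastFailure = null;
                for (int attempt = 1; attempt <= maxAttempts; attempt++) {
                    try {
                        base.evaluate();   // run the test (and any inner rules)
                        return;            // passed: stop retrying
                    } catch (Throwable t) {
                        lastFailure = t;
                        System.err.printf("%s failed on attempt %d%n",
                                description.getDisplayName(), attempt);
                    }
                }
                // Every attempt failed: report the last error to JUnit.
                throw lastFailure;
            }
        };
    }
}
```

In a test class it is declared like any other rule, e.g. `@Rule public final RetryRule retry = new RetryRule(3);`. Because it catches Throwable, this sketch also re-runs tests that were skipped via assumptions; a production version would probably let those through untouched.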

The first approach (extending the ExternalResource class) is simpler and covers most needs. In the following example we’re going to write our own rule that measures the execution time of unit tests. To do that we’re going to do three things:

  1. Start measuring the test execution time in the ExternalResource.before() method.
  2. Grab the current test name from the Description object passed to apply().
  3. Calculate the total test execution time in the ExternalResource.after() method.

We’re using the @Rule annotation because we want the rule to be applied to each test. If we wanted to run it once for the whole test class, we would use @ClassRule instead.

package com.farenda.junit;

import org.junit.*;
import org.junit.rules.ExternalResource;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import java.util.Random;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.junit.Assert.assertEquals;

public class PerformanceLoggingRuleTest {

    // @Rule makes JUnit apply this rule around every test method.
    @Rule
    public ExternalResource performanceLogger
            = new ExternalResource() {

        private long testStart;
        private String testName;

        @Override
        protected void before() throws Throwable {
            // Do some pre-test action here - e.g. open connection.
            testStart = System.currentTimeMillis();
        }

        // Overridden only to grab the test name; the Statement is left unchanged.
        @Override
        public Statement apply(Statement base, Description description) {
            // You can do some action here.
            testName = description.getMethodName();

            // Return new Statement() {...} if you want to evaluate differently
            return super.apply(base, description);
        }

        @Override
        protected void after() {
            // Do some post-test action - e.g. close connection.
            System.out.printf("%s executed in: %dms%n",
                    testName, System.currentTimeMillis() - testStart);
        }
    };

    @Test
    public void cryptoCruncher() throws InterruptedException {
        MILLISECONDS.sleep(new Random().nextInt(1000));
        assertEquals("Java", "Java");
    }

    @Test
    public void cpuStealer() throws InterruptedException {
        MILLISECONDS.sleep(new Random().nextInt(1000));
        assertEquals(5, "JUnit".length());
    }
}
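For comparison, here is a sketch of the @ClassRule variant, assuming JUnit 4.12 or later; the class name ClassPerformanceLoggingTest is made up for illustration. A class rule must be a public static field, and its before()/after() run once around all tests in the class, so it reports the total time of the class rather than of each test.

```java
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.ExternalResource;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import static org.junit.Assert.assertEquals;

public class ClassPerformanceLoggingTest {

    // @ClassRule requires a public static field; JUnit applies it once per class.
    @ClassRule
    public static final ExternalResource classTimer = new ExternalResource() {

        private long classStart;
        private String className;

        @Override
        public Statement apply(Statement base, Description description) {
            // For a class rule the Description describes the test class, not a method.
            className = description.getDisplayName();
            return super.apply(base, description);
        }

        @Override
        protected void before() {
            classStart = System.currentTimeMillis();
        }

        @Override
        protected void after() {
            System.out.printf("%s executed in: %dms%n",
                    className, System.currentTimeMillis() - classStart);
        }
    };

    @Test
    public void first() {
        assertEquals(4, 2 + 2);
    }

    @Test
    public void second() {
        assertEquals("JUnit", "J" + "Unit");
    }
}
```

Running this class prints a single timing line after both tests have finished, instead of one line per test as with @Rule.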

Custom implementations of TestRule are much less common, but sometimes you may need one. Let’s see how the pros do it.

How the pros do it

The following SpringClassRule from the Spring Framework implements the org.junit.rules.TestRule interface to work around JUnit’s limitation of allowing only one runner per test class. It wraps org.junit.runners.model.Statement objects to add class-level behavior, including managing a cache of TestContextManagers keyed by test class:

package org.springframework.test.context.junit4.rules;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Rule;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import org.springframework.test.context.TestContextManager;
import org.springframework.test.context.junit4.statements.ProfileValueChecker;
import org.springframework.test.context.junit4.statements.RunAfterTestClassCallbacks;
import org.springframework.test.context.junit4.statements.RunBeforeTestClassCallbacks;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;

public class SpringClassRule implements TestRule {

    private static final Log logger = LogFactory.getLog(SpringClassRule.class);

    /**
     * Cache of {@code TestContextManagers} keyed by test class.
     */
    private static final Map<Class<?>, TestContextManager> testContextManagerCache =
            new ConcurrentHashMap<Class<?>, TestContextManager>(64);

    // Used by RunAfterTestClassCallbacks
    private static final String MULTIPLE_FAILURE_EXCEPTION_CLASS_NAME = "org.junit.runners.model.MultipleFailureException";

    static {
        boolean junit4dot9Present = ClassUtils.isPresent(MULTIPLE_FAILURE_EXCEPTION_CLASS_NAME,
                SpringClassRule.class.getClassLoader());
        if (!junit4dot9Present) {
            throw new IllegalStateException(String.format(
                "Failed to find class [%s]: SpringClassRule requires JUnit 4.9 or higher.",
                MULTIPLE_FAILURE_EXCEPTION_CLASS_NAME));
        }
    }

    /**
     * Apply <em>class-level</em> features of the <em>Spring TestContext
     * Framework</em> to the supplied {@code base} statement.
     * <p>Specifically, this method retrieves the {@link TestContextManager}
     * used by this rule and its associated {@link SpringMethodRule} and
     * invokes the {@link TestContextManager#beforeTestClass() beforeTestClass()}
     * and {@link TestContextManager#afterTestClass() afterTestClass()} methods
     * on the {@code TestContextManager}.
     * <p>In addition, this method checks whether the test is enabled in
     * the current execution environment. This prevents classes with a
     * non-matching {@code @IfProfileValue} annotation from running altogether,
     * even skipping the execution of {@code beforeTestClass()} methods
     * in {@code TestExecutionListeners}.
     * @param base the base {@code Statement} that this rule should be applied to
     * @param description a {@code Description} of the current test execution
     * @return a statement that wraps the supplied {@code base} with class-level
     * features of the Spring TestContext Framework
     * @see #getTestContextManager
     * @see #withBeforeTestClassCallbacks
     * @see #withAfterTestClassCallbacks
     * @see #withProfileValueCheck
     * @see #withTestContextManagerCacheEviction
     */
    @Override
    public Statement apply(Statement base, Description description) {
        Class<?> testClass = description.getTestClass();
        if (logger.isDebugEnabled()) {
            logger.debug("Applying SpringClassRule to test class [" + testClass.getName() + "]");
        }
        TestContextManager testContextManager = getTestContextManager(testClass);

        Statement statement = base;
        statement = withBeforeTestClassCallbacks(statement, testContextManager);
        statement = withAfterTestClassCallbacks(statement, testContextManager);
        statement = withProfileValueCheck(statement, testClass);
        statement = withTestContextManagerCacheEviction(statement, testClass);
        return statement;
    }

    private Statement withBeforeTestClassCallbacks(Statement statement, TestContextManager testContextManager) {
        return new RunBeforeTestClassCallbacks(statement, testContextManager);
    }

    private Statement withAfterTestClassCallbacks(Statement statement, TestContextManager testContextManager) {
        return new RunAfterTestClassCallbacks(statement, testContextManager);
    }

    private Statement withProfileValueCheck(Statement statement, Class<?> testClass) {
        return new ProfileValueChecker(statement, testClass, null);
    }

    private Statement withTestContextManagerCacheEviction(Statement statement, Class<?> testClass) {
        return new TestContextManagerCacheEvictor(statement, testClass);
    }

    /**
     * Throw an {@link IllegalStateException} if the supplied {@code testClass}
     * does not declare a {@code public SpringMethodRule} field that is
     * annotated with {@code @Rule}.
     */
    private static void validateSpringMethodRuleConfiguration(Class<?> testClass) {
        Field ruleField = null;

        for (Field field : testClass.getFields()) {
            int modifiers = field.getModifiers();
            if (!Modifier.isStatic(modifiers) && Modifier.isPublic(modifiers) &&
                    SpringMethodRule.class.isAssignableFrom(field.getType())) {
                ruleField = field;
            }
        }

        if (ruleField == null) {
            throw new IllegalStateException(String.format(
                    "Failed to find 'public SpringMethodRule' field in test class [%s]. " +
                    "Consult the javadoc for SpringClassRule for details.", testClass.getName()));
        }

        if (!ruleField.isAnnotationPresent(Rule.class)) {
            throw new IllegalStateException(String.format(
                    "SpringMethodRule field [%s] must be annotated with JUnit's @Rule annotation. " +
                    "Consult the javadoc for SpringClassRule for details.", ruleField));
        }
    }

    /**
     * Get the {@link TestContextManager} associated with the supplied test class.
     * @param testClass the test class to be managed; never {@code null}
     */
    static TestContextManager getTestContextManager(Class<?> testClass) {
        Assert.notNull(testClass, "testClass must not be null");
        synchronized (testContextManagerCache) {
            TestContextManager testContextManager = testContextManagerCache.get(testClass);
            if (testContextManager == null) {
                testContextManager = new TestContextManager(testClass);
                testContextManagerCache.put(testClass, testContextManager);
            }
            return testContextManager;
        }
    }

    private static class TestContextManagerCacheEvictor extends Statement {

        private final Statement next;

        private final Class<?> testClass;

        TestContextManagerCacheEvictor(Statement next, Class<?> testClass) {
            this.next = next;
            this.testClass = testClass;
        }

        @Override
        public void evaluate() throws Throwable {
            try {
                next.evaluate();
            }
            finally {
                // Evict the cached TestContextManager once the whole class has run.
                testContextManagerCache.remove(testClass);
            }
        }
    }
}
The code of masters is full of wisdom!
