001/*
002 * HA-JDBC: High-Availability JDBC
003 * Copyright (C) 2012  Paul Ferraro
004 *
005 * This program is free software: you can redistribute it and/or modify
006 * it under the terms of the GNU Lesser General Public License as published by
007 * the Free Software Foundation, either version 3 of the License, or
008 * (at your option) any later version.
009 *
010 * This program is distributed in the hope that it will be useful,
011 * but WITHOUT ANY WARRANTY; without even the implied warranty of
012 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
013 * GNU Lesser General Public License for more details.
014 *
015 * You should have received a copy of the GNU Lesser General Public License
016 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
017 */
018package net.sf.hajdbc.sql.xa;
019
020import java.lang.reflect.Method;
021import java.sql.SQLException;
022import java.util.Arrays;
023import java.util.HashSet;
024import java.util.Set;
025import java.util.SortedMap;
026import java.util.concurrent.ConcurrentHashMap;
027import java.util.concurrent.ConcurrentMap;
028import java.util.concurrent.locks.Lock;
029
030import javax.sql.XAConnection;
031import javax.sql.XADataSource;
032import javax.transaction.xa.XAException;
033import javax.transaction.xa.XAResource;
034import javax.transaction.xa.Xid;
035
036import net.sf.hajdbc.Database;
037import net.sf.hajdbc.DatabaseCluster;
038import net.sf.hajdbc.durability.Durability;
039import net.sf.hajdbc.invocation.InvocationStrategies;
040import net.sf.hajdbc.invocation.InvocationStrategy;
041import net.sf.hajdbc.invocation.Invoker;
042import net.sf.hajdbc.sql.ChildInvocationHandler;
043import net.sf.hajdbc.sql.DurabilityPhaseRegistry;
044import net.sf.hajdbc.sql.ProxyFactory;
045import net.sf.hajdbc.util.StaticRegistry;
046import net.sf.hajdbc.util.reflect.Methods;
047
048/**
049 * @author Paul Ferraro
050 *
051 */
052@SuppressWarnings("nls")
053public class XAResourceInvocationHandler extends ChildInvocationHandler<XADataSource, XADataSourceDatabase, XAConnection, SQLException, XAResource, XAException, XAResourceProxyFactory>
054{
055        private static final Set<Method> driverReadMethodSet = Methods.findMethods(XAResource.class, "getTransactionTimeout", "isSameRM");
056        private static final Set<Method> databaseWriteMethodSet = Methods.findMethods(XAResource.class, "setTransactionTimeout");
057        private static final Set<Method> intraTransactionMethodSet = Methods.findMethods(XAResource.class, "prepare", "end", "recover");
058        private static final Method prepareMethod = Methods.getMethod(XAResource.class, "prepare", Xid.class);
059        private static final Method startMethod = Methods.getMethod(XAResource.class, "start", Xid.class, Integer.TYPE);
060        private static final Method commitMethod = Methods.getMethod(XAResource.class, "commit", Xid.class, Boolean.TYPE);
061        private static final Method rollbackMethod = Methods.getMethod(XAResource.class, "rollback", Xid.class);
062        private static final Method forgetMethod = Methods.getMethod(XAResource.class, "forget", Xid.class);
063        private static final Set<Method> endTransactionMethodSet = new HashSet<Method>(Arrays.asList(commitMethod, rollbackMethod, forgetMethod));
064        
065        private static final StaticRegistry<Method, Durability.Phase> phaseRegistry = new DurabilityPhaseRegistry(Arrays.asList(prepareMethod), Arrays.asList(commitMethod), Arrays.asList(rollbackMethod), Arrays.asList(forgetMethod));
066        
067        // Xids are global - so store in static variable
068        private static final ConcurrentMap<Xid, Lock> lockMap = new ConcurrentHashMap<Xid, Lock>();
069
070        public XAResourceInvocationHandler(XAResourceProxyFactory proxyFactory)
071        {
072                super(XAResource.class, proxyFactory, null);
073        }
074
075        /**
076         * @see net.sf.hajdbc.sql.AbstractInvocationHandler#getInvocationStrategy(java.lang.Object, java.lang.reflect.Method, java.lang.Object[])
077         */
078        @Override
079        protected InvocationStrategy getInvocationStrategy(XAResource resource, Method method, Object... parameters) throws XAException
080        {
081                if (driverReadMethodSet.contains(method))
082                {
083                        return InvocationStrategies.INVOKE_ON_ANY;
084                }
085                
086                if (databaseWriteMethodSet.contains(method))
087                {
088                        return InvocationStrategies.INVOKE_ON_ALL;
089                }
090                
091                boolean start = method.equals(startMethod);
092                boolean end = endTransactionMethodSet.contains(method);
093                
094                if (start || end || method.equals(prepareMethod) || intraTransactionMethodSet.contains(method))
095                {
096                        final InvocationStrategy strategy = end ? InvocationStrategies.END_TRANSACTION_INVOKE_ON_ALL : InvocationStrategies.TRANSACTION_INVOKE_ON_ALL;
097                        
098                        Xid xid = (Xid) parameters[0];
099                        
100                        DatabaseCluster<XADataSource, XADataSourceDatabase> cluster = this.getProxyFactory().getDatabaseCluster();
101                        
102                        if (start)
103                        {
104                                final Lock lock = cluster.getLockManager().readLock(null);
105                                
106                                // Lock may already exist if we're resuming a suspended transaction
107                                if (lockMap.putIfAbsent(xid, lock) == null)
108                                {
109                                        return new InvocationStrategy()
110                                        {
111                                                @Override
112                                                public <Z, D extends Database<Z>, T, R, E extends Exception> SortedMap<D, R> invoke(ProxyFactory<Z, D, T, E> proxy, Invoker<Z, D, T, R, E> invoker) throws E
113                                                {
114                                                        lock.lock();
115                                                        
116                                                        try
117                                                        {
118                                                                return strategy.invoke(proxy, invoker);
119                                                        }
120                                                        catch (Exception e)
121                                                        {
122                                                                lock.unlock();
123
124                                                                throw proxy.getExceptionFactory().createException(e);
125                                                        }
126                                                }
127                                        };
128                                }
129                        }
130                        
131                        Durability.Phase phase = phaseRegistry.get(method);
132                        
133                        if (phase != null)
134                        {
135                                final InvocationStrategy durabilityStrategy = cluster.getDurability().getInvocationStrategy(strategy, phase, xid);
136                                
137                                if (endTransactionMethodSet.contains(method))
138                                {
139                                        final Lock lock = lockMap.remove(xid);
140
141                                        return new InvocationStrategy()
142                                        {
143                                                @Override
144                                                public <Z, D extends Database<Z>, T, R, E extends Exception> SortedMap<D, R> invoke(ProxyFactory<Z, D, T, E> proxy, Invoker<Z, D, T, R, E> invoker) throws E
145                                                {
146                                                        try
147                                                        {
148                                                                return durabilityStrategy.invoke(proxy, invoker);
149                                                        }
150                                                        finally
151                                                        {
152                                                                if (lock != null)
153                                                                {
154                                                                        lock.unlock();
155                                                                }
156                                                        }
157                                                }
158                                        };
159                                }
160                                
161                                return durabilityStrategy;
162                        }
163                        
164                        return strategy;
165                }
166                
167                return super.getInvocationStrategy(resource, method, parameters);
168        }
169
170        /**
171         * {@inheritDoc}
172         * @see net.sf.hajdbc.sql.AbstractInvocationHandler#getInvoker(java.lang.Object, java.lang.reflect.Method, java.lang.Object[])
173         */
174        @Override
175        protected <R> Invoker<XADataSource, XADataSourceDatabase, XAResource, R, XAException> getInvoker(XAResource object, Method method, Object... parameters) throws XAException
176        {
177                Invoker<XADataSource, XADataSourceDatabase, XAResource, R, XAException> invoker = super.getInvoker(object, method, parameters);
178                
179                Durability.Phase phase = phaseRegistry.get(method);
180                
181                if (method.equals(prepareMethod) || endTransactionMethodSet.contains(method))
182                {
183                        return this.getProxyFactory().getDatabaseCluster().getDurability().getInvoker(invoker, phase, parameters[0], this.getProxyFactory().getExceptionFactory());
184                }
185                
186                return invoker;
187        }
188
189        @Override
190        protected <R> void postInvoke(Invoker<XADataSource, XADataSourceDatabase, XAResource, R, XAException> invoker, XAResource proxy, Method method, Object... parameters)
191        {
192                if (databaseWriteMethodSet.contains(method))
193                {
194                        this.getProxyFactory().record(invoker);
195                }
196        }
197}