View Javadoc
1   /*
2    *  Licensed to the Apache Software Foundation (ASF) under one
3    *  or more contributor license agreements.  See the NOTICE file
4    *  distributed with this work for additional information
5    *  regarding copyright ownership.  The ASF licenses this file
6    *  to you under the Apache License, Version 2.0 (the
7    *  "License"); you may not use this file except in compliance
8    *  with the License.  You may obtain a copy of the License at
9    *
10   *    http://www.apache.org/licenses/LICENSE-2.0
11   *
12   *  Unless required by applicable law or agreed to in writing,
13   *  software distributed under the License is distributed on an
14   *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15   *  KIND, either express or implied.  See the License for the
16   *  specific language governing permissions and limitations
17   *  under the License.
18   *
19   */
20  package org.apache.mina.filter.codec.demux;
21  
22  import java.util.Map;
23  import java.util.Set;
24  import java.util.concurrent.ConcurrentHashMap;
25  
26  import org.apache.mina.core.session.AttributeKey;
27  import org.apache.mina.core.session.IoSession;
28  import org.apache.mina.core.session.UnknownMessageTypeException;
29  import org.apache.mina.filter.codec.ProtocolEncoder;
30  import org.apache.mina.filter.codec.ProtocolEncoderOutput;
31  import org.apache.mina.util.CopyOnWriteMap;
32  import org.apache.mina.util.IdentityHashSet;
33  
34  /**
35   * A composite {@link ProtocolEncoder} that demultiplexes incoming message
36   * encoding requests into an appropriate {@link MessageEncoder}.
37   *
38   * <h2>Disposing resources acquired by {@link MessageEncoder}</h2>
39   * <p>
40   * Override {@link #dispose(IoSession)} method. Please don't forget to call
41   * <tt>super.dispose()</tt>.
42   *
43   * @author <a href="http://mina.apache.org">Apache MINA Project</a>
44   *
45   * @see MessageEncoderFactory
46   * @see MessageEncoder
47   */
48  public class DemuxingProtocolEncoder implements ProtocolEncoder {
49  
50      private final AttributeKey STATE = new AttributeKey(getClass(), "state");
51  
52      @SuppressWarnings("rawtypes")
53      private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>();
54  
55      private static final Class<?>[] EMPTY_PARAMS = new Class[0];
56  
57      public DemuxingProtocolEncoder() {
58          // Do nothing
59      }
60  
61      @SuppressWarnings({ "rawtypes", "unchecked" })
62      public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
63          if (encoderClass == null) {
64              throw new IllegalArgumentException("encoderClass");
65          }
66  
67          try {
68              encoderClass.getConstructor(EMPTY_PARAMS);
69          } catch (NoSuchMethodException e) {
70              throw new IllegalArgumentException("The specified class doesn't have a public default constructor.");
71          }
72  
73          boolean registered = false;
74          if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
75              addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
76              registered = true;
77          }
78  
79          if (!registered) {
80              throw new IllegalArgumentException("Unregisterable type: " + encoderClass);
81          }
82      }
83  
84      @SuppressWarnings({ "unchecked", "rawtypes" })
85      public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
86          addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
87      }
88  
89      public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
90          if (messageType == null) {
91              throw new IllegalArgumentException("messageType");
92          }
93  
94          if (factory == null) {
95              throw new IllegalArgumentException("factory");
96          }
97  
98          synchronized (type2encoderFactory) {
99              if (type2encoderFactory.containsKey(messageType)) {
100                 throw new IllegalStateException("The specified message type (" + messageType.getName()
101                         + ") is registered already.");
102             }
103 
104             type2encoderFactory.put(messageType, factory);
105         }
106     }
107 
108     @SuppressWarnings("rawtypes")
109     public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
110         for (Class<?> messageType : messageTypes) {
111             addMessageEncoder(messageType, encoderClass);
112         }
113     }
114 
115     public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
116         for (Class<? extends T> messageType : messageTypes) {
117             addMessageEncoder(messageType, encoder);
118         }
119     }
120 
121     public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes,
122             MessageEncoderFactory<? super T> factory) {
123         for (Class<? extends T> messageType : messageTypes) {
124             addMessageEncoder(messageType, factory);
125         }
126     }
127 
128     /**
129      * {@inheritDoc}
130      */
131     public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
132         State state = getState(session);
133         MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
134         if (encoder != null) {
135             encoder.encode(session, message, out);
136         } else {
137             throw new UnknownMessageTypeException("No message encoder found for message: " + message);
138         }
139     }
140 
141     protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
142         return findEncoder(state, type, null);
143     }
144 
145     @SuppressWarnings("unchecked")
146     private MessageEncoder<Object> findEncoder(State state, Class<?> type, Set<Class<?>> triedClasses) {
147         @SuppressWarnings("rawtypes")
148         MessageEncoder encoder = null;
149 
150         if (triedClasses != null && triedClasses.contains(type)) {
151             return null;
152         }
153 
154         /*
155          * Try the cache first.
156          */
157         encoder = state.findEncoderCache.get(type);
158 
159         if (encoder != null) {
160             return encoder;
161         }
162 
163         /*
164          * Try the registered encoders for an immediate match.
165          */
166         encoder = state.type2encoder.get(type);
167 
168         if (encoder == null) {
169             /*
170              * No immediate match could be found. Search the type's interfaces.
171              */
172 
173             if (triedClasses == null) {
174                 triedClasses = new IdentityHashSet<Class<?>>();
175             }
176 
177             triedClasses.add(type);
178 
179             Class<?>[] interfaces = type.getInterfaces();
180 
181             for (Class<?> element : interfaces) {
182                 encoder = findEncoder(state, element, triedClasses);
183 
184                 if (encoder != null) {
185                     break;
186                 }
187             }
188         }
189 
190         if (encoder == null) {
191             /*
192              * No match in type's interfaces could be found. Search the
193              * superclass.
194              */
195 
196             Class<?> superclass = type.getSuperclass();
197 
198             if (superclass != null) {
199                 encoder = findEncoder(state, superclass);
200             }
201         }
202 
203         /*
204          * Make sure the encoder is added to the cache. By updating the cache
205          * here all the types (superclasses and interfaces) in the path which
206          * led to a match will be cached along with the immediate message type.
207          */
208         if (encoder != null) {
209             state.findEncoderCache.put(type, encoder);
210             MessageEncoder<Object> tmpEncoder = state.findEncoderCache.putIfAbsent(type, encoder);
211 
212             if (tmpEncoder != null) {
213                 encoder = tmpEncoder;
214             }
215         }
216 
217         return encoder;
218     }
219 
220     /**
221      * {@inheritDoc}
222      */
223     public void dispose(IoSession session) throws Exception {
224         session.removeAttribute(STATE);
225     }
226 
227     private State getState(IoSession session) throws Exception {
228         State state = (State) session.getAttribute(STATE);
229         if (state == null) {
230             state = new State();
231             State oldState = (State) session.setAttributeIfAbsent(STATE, state);
232             if (oldState != null) {
233                 state = oldState;
234             }
235         }
236         return state;
237     }
238 
239     private class State {
240         @SuppressWarnings("rawtypes")
241         private final ConcurrentHashMap<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<Class<?>, MessageEncoder>();
242 
243         @SuppressWarnings("rawtypes")
244         private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<Class<?>, MessageEncoder>();
245 
246         @SuppressWarnings("rawtypes")
247         private State() throws Exception {
248             for (Map.Entry<Class<?>, MessageEncoderFactory> e : type2encoderFactory.entrySet()) {
249                 type2encoder.put(e.getKey(), e.getValue().getEncoder());
250             }
251         }
252     }
253 
254     private static class SingletonMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
255         private final MessageEncoder<T> encoder;
256 
257         private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
258             if (encoder == null) {
259                 throw new IllegalArgumentException("encoder");
260             }
261             this.encoder = encoder;
262         }
263 
264         public MessageEncoder<T> getEncoder() {
265             return encoder;
266         }
267     }
268 
269     private static class DefaultConstructorMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
270         private final Class<MessageEncoder<T>> encoderClass;
271 
272         private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
273             if (encoderClass == null) {
274                 throw new IllegalArgumentException("encoderClass");
275             }
276 
277             if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
278                 throw new IllegalArgumentException("encoderClass is not assignable to MessageEncoder");
279             }
280             this.encoderClass = encoderClass;
281         }
282 
283         public MessageEncoder<T> getEncoder() throws Exception {
284             return encoderClass.newInstance();
285         }
286     }
287 }