1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  package com.github.advisedtesting.classloader;
24  
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
35  
/**
 * A class loader that defines classes itself instead of delegating to its parent. It reads the
 * bytecode through the class loader of this object's runtime class.
 *
 * <p>Classes whose names start with one of the configured prefixes are the exception: they are
 * delegated to the parent as usual. Every class this loader defines itself is first passed through
 * a {@link ClassFileTransformer}. If the transformer throws a {@link ClassFormatError}, the error's
 * message is recorded and can be retrieved with {@link #getError(String)}.
 */
public class EvictingClassLoader extends ClassLoader {

  /**
   * Name prefixes that are always delegated to the parent loader, in addition to the prefixes
   * supplied to the constructor. This is a public mutable array; callers should not modify it.
   */
  public static final String[] DEFAULT_EXCLUDED_PACKAGES =
          new String[] {"java.", "javax.", "sun.", "oracle.", "com.sun.", "com.ibm.", "COM.ibm.",
              "org.w3c.", "org.xml.", "org.dom4j.", "org.eclipse", "org.aspectj.", "net.sf.cglib",
              "org.springframework.cglib", "org.apache.xerces.", "org.apache.commons.logging."};

  /**
   * Class-name prefixes delegated to the parent loader: the caller's prefixes plus
   * {@link #DEFAULT_EXCLUDED_PACKAGES}. Despite the name, classes matching these prefixes are
   * excluded from local loading.
   */
  private final List<String> whiteList;

  /** Applied to the bytecode of every class this loader defines itself. */
  private final ClassFileTransformer transformer;

  /** Messages of {@link ClassFormatError}s thrown by the transformer, keyed by class name. */
  private final Map<String, String> classNameToError = new HashMap<>();

  /**
   * Creates a loader.
   *
   * @param whiteList class-name prefixes to delegate to {@code parent}. The list is copied and is
   *     never modified; {@link #DEFAULT_EXCLUDED_PACKAGES} are appended to the copy.
   * @param transformer applied to the bytecode of every class this loader defines itself
   * @param parent the loader used for delegated classes
   */
  public EvictingClassLoader(List<String> whiteList, ClassFileTransformer transformer, ClassLoader parent) {
    super(parent);
    // Copy so the caller's list is left untouched and fixed-size/immutable lists are accepted.
    List<String> prefixes = new ArrayList<>(whiteList);
    prefixes.addAll(Arrays.asList(DEFAULT_EXCLUDED_PACKAGES));
    this.whiteList = prefixes;
    this.transformer = transformer;
  }

  /**
   * Returns the class this loader already defined for {@code name}. Otherwise, reads its bytecode,
   * runs it through the transformer, and defines and resolves it in this loader.
   *
   * @throws ClassNotFoundException if the class file cannot be found or read, or if the
   *     transformer throws an {@link IllegalClassFormatException}
   * @throws ClassFormatError if the transformer throws one (its message is recorded first), or if
   *     the final bytes are not a valid class file
   */
  private Class<?> defineFromResource(String name) throws ClassNotFoundException {
    // Already defined by an earlier call: no need to re-read or re-transform the bytecode.
    Class<?> loaded = findLoadedClass(name);
    if (loaded != null) {
      return loaded;
    }
    // Classpath resource names are always '/'-separated, whatever the platform's file separator.
    String resource = name.replace('.', '/') + ".class";
    byte[] bytes;
    try {
      bytes = loadClassData(resource);
    } catch (IOException ioe) {
      throw new ClassNotFoundException(name, ioe);
    }
    byte[] transformed;
    try {
      // NOTE(review): ClassFileTransformer documents className in internal form ("a/b/C");
      // this passes the dotted binary name and a null loader -- confirm the transformers in use
      // expect that.
      transformed = transformer.transform(null, name, null, null, bytes);
    } catch (ClassFormatError error) {
      classNameToError.put(name, error.getMessage());
      throw error;
    } catch (IllegalClassFormatException icfe) {
      throw new ClassNotFoundException(name, icfe);
    }
    // Per the ClassFileTransformer contract, null means "no transformation performed".
    if (transformed != null) {
      bytes = transformed;
    }
    Class<?> defined = defineClass(name, bytes, 0, bytes.length);
    resolveClass(defined);
    return defined;
  }

  /** Returns whether {@code name} starts with a configured prefix and so belongs to the parent. */
  private boolean isDelegated(String name) {
    for (String prefix : whiteList) {
      if (name.startsWith(prefix)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Loads {@code name}, equivalent to {@code loadClass(name, false)}.
   *
   * @throws ClassNotFoundException if the class cannot be found
   */
  @Override
  public Class<?> loadClass(String name) throws ClassNotFoundException {
    return loadClass(name, false);
  }

  /**
   * Delegates classes matching a configured prefix to the parent. Every other class is defined by
   * this loader (see {@link #defineFromResource(String)}).
   *
   * @param resolve honoured only for delegated classes; classes this loader defines itself are
   *     always resolved
   * @throws ClassNotFoundException if the class cannot be found
   */
  @Override
  public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
    if (isDelegated(name)) {
      return super.loadClass(name, resolve);
    }
    return defineFromResource(name);
  }

  /**
   * Reads the complete contents of the classpath resource {@code name}. The resource is looked up
   * through the class loader of this object's runtime class.
   *
   * @throws FileNotFoundException if no such resource exists
   * @throws IOException if reading fails
   */
  private byte[] loadClassData(String name) throws IOException {
    InputStream stream = getClass().getClassLoader().getResourceAsStream(name);
    if (stream == null) {
      throw new FileNotFoundException("Class resource not found: " + name);
    }
    // Read until EOF: InputStream.available() is only an estimate, not the total size.
    try (InputStream in = stream) {
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      byte[] chunk = new byte[8192];
      int read;
      while ((read = in.read(chunk)) != -1) {
        out.write(chunk, 0, read);
      }
      return out.toByteArray();
    }
  }

  /**
   * Returns the message of the {@link ClassFormatError} the transformer threw for
   * {@code className}. Returns {@code null} if no error was recorded or the error had no message.
   */
  public String getError(String className) {
    return classNameToError.get(className);
  }
}