1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23 package com.github.advisedtesting.classloader;
24
25 import java.io.DataInputStream;
26 import java.io.File;
27 import java.io.IOException;
28 import java.io.InputStream;
29 import java.lang.instrument.ClassFileTransformer;
30 import java.lang.instrument.IllegalClassFormatException;
31 import java.util.Arrays;
32 import java.util.HashMap;
33 import java.util.List;
34 import java.util.Map;
35
/**
 * A {@link ClassLoader} that defines most classes itself instead of delegating to its parent.
 *
 * <p>For every class name that does not start with one of the configured prefixes (the
 * caller-supplied prefixes plus {@link #DEFAULT_EXCLUDED_PACKAGES}), the class file is read as a
 * resource from the class loader of this object's runtime class, passed to a
 * {@link ClassFileTransformer}, and then defined in this loader. Names that match a prefix are
 * delegated to the parent loader.
 *
 * <p>If the transformer throws a {@link ClassFormatError}, its message is recorded under the class
 * name (see {@link #getError(String)}) and the error is rethrown to the caller.
 *
 * <p>No synchronization is applied to the internal error map.
 */
public class EvictingClassLoader extends ClassLoader {

  /**
   * Class-name prefixes that are always delegated to the parent loader. They are appended to the
   * caller-supplied prefixes when a loader is constructed.
   *
   * <p>NOTE(review): this is a public, mutable array; modifying an element affects every loader
   * constructed afterwards.
   */
  public static final String[] DEFAULT_EXCLUDED_PACKAGES =
      new String[] {"java.", "javax.", "sun.", "oracle.", "com.sun.", "com.ibm.", "COM.ibm.",
        "org.w3c.", "org.xml.", "org.dom4j.", "org.eclipse", "org.aspectj.", "net.sf.cglib",
        "org.springframework.cglib", "org.apache.xerces.", "org.apache.commons.logging."};

  /** Prefixes of class names delegated to the parent: the caller's prefixes, then the defaults. */
  private final List<String> whiteList;

  /** Invoked on the bytes of every class this loader defines itself. */
  private final ClassFileTransformer transformer;

  /** Class name mapped to the message of the ClassFormatError the transformer threw for it. */
  private final Map<String, String> classNameToError = new HashMap<>();

  /**
   * Creates a loader.
   *
   * @param whiteList prefixes of class names to delegate to {@code parent}; the list is copied and
   *     never modified, so immutable or fixed-size lists are accepted
   * @param transformer invoked on each class file before this loader defines it
   * @param parent the loader that prefix-matching class names are delegated to
   * @throws NullPointerException if {@code whiteList} is null
   */
  public EvictingClassLoader(List<String> whiteList, ClassFileTransformer transformer, ClassLoader parent) {
    super(parent);
    // Copy instead of mutating the caller's list (which may also be unmodifiable).
    List<String> prefixes =
        new java.util.ArrayList<>(java.util.Objects.requireNonNull(whiteList, "whiteList"));
    prefixes.addAll(Arrays.asList(DEFAULT_EXCLUDED_PACKAGES));
    this.whiteList = prefixes;
    this.transformer = transformer;
  }

  /**
   * Returns {@code true} when {@code name} starts with any configured prefix and must therefore be
   * loaded by the parent.
   */
  private boolean isDelegatedToParent(String name) {
    for (String prefix : whiteList) {
      if (name.startsWith(prefix)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Reads, checks and defines {@code name} in this loader, or returns the class this loader already
   * defined under that name.
   *
   * @param name binary class name, e.g. {@code com.example.Foo}
   * @return the class defined by this loader
   * @throws ClassNotFoundException if the class file cannot be found or read, or the transformer
   *     throws {@link IllegalClassFormatException}
   * @throws ClassFormatError if the transformer throws one (its message is recorded first)
   */
  private Class<?> loadLocally(String name) throws ClassNotFoundException {
    // Resource names always use '/', independent of the platform's file separator.
    String resourceName = name.replace('.', '/') + ".class";
    byte[] bytes;
    try {
      bytes = loadClassData(resourceName);
    } catch (IOException ioe) {
      // loadClass must return a class or throw; never hand back null.
      throw new ClassNotFoundException(name, ioe);
    }
    try {
      // NOTE(review): the transformer's return value is discarded, so the original bytes are
      // always defined; the transformer appears to act only as a validator here. The dotted binary
      // name is passed rather than the '/'-separated internal form the ClassFileTransformer docs
      // describe. Confirm both are intentional before changing either.
      transformer.transform(null, name, null, null, bytes);
    } catch (ClassFormatError error) {
      classNameToError.put(name, error.getMessage());
      throw error;
    } catch (IllegalClassFormatException icfe) {
      throw new ClassNotFoundException(name, icfe);
    }
    // The class file is re-read and re-checked on every request; an earlier definition wins.
    Class<?> loaded = findLoadedClass(name);
    if (loaded != null) {
      return loaded;
    }
    Class<?> defined = defineClass(name, bytes, 0, bytes.length);
    resolveClass(defined);
    return defined;
  }

  /**
   * Loads {@code name}, delegating to the parent for prefix-matching names and defining every other
   * class in this loader. Equivalent to {@code loadClass(name, false)}.
   */
  @Override
  public Class<?> loadClass(String name) throws ClassNotFoundException {
    return loadClass(name, false);
  }

  /**
   * Loads {@code name}, delegating to the parent for prefix-matching names and defining every other
   * class in this loader.
   *
   * @param name binary class name
   * @param resolve forwarded to the parent for delegated names; classes defined by this loader are
   *     always passed to {@link #resolveClass(Class)}
   * @return the loaded class
   * @throws ClassNotFoundException if the class cannot be found or read
   * @throws ClassFormatError if the transformer rejects a class this loader would define
   */
  @Override
  public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
    if (isDelegatedToParent(name)) {
      return super.loadClass(name, resolve);
    }
    return loadLocally(name);
  }

  /**
   * Reads the complete contents of a class-file resource from the class loader of this object's
   * runtime class.
   *
   * @param name '/'-separated resource name, e.g. {@code com/example/Foo.class}
   * @return all bytes of the resource
   * @throws IOException if the resource does not exist or cannot be read
   */
  private byte[] loadClassData(String name) throws IOException {
    InputStream resource = getClass().getClassLoader().getResourceAsStream(name);
    if (resource == null) {
      throw new IOException("Class resource not found: " + name);
    }
    // available() is only an estimate, so read until end of stream instead of trusting it.
    try (InputStream in = resource) {
      java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream();
      byte[] chunk = new byte[8192];
      for (int read = in.read(chunk); read != -1; read = in.read(chunk)) {
        out.write(chunk, 0, read);
      }
      return out.toByteArray();
    }
  }

  /**
   * Returns the message of the most recent {@link ClassFormatError} the transformer threw for
   * {@code className}.
   *
   * @param className binary class name
   * @return the recorded message, or {@code null} if none was recorded (or the error's message
   *     was itself null)
   */
  public String getError(String className) {
    return classNameToError.get(className);
  }
}