Spring 中如何覆盖容器中已有的 Bean?
下文笔者将分享在 Spring 中覆盖 Bean 的方法,如下所示:
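覆盖 Bean 的实现思路:
1. 自定义一个注解 @OverrideBean,标注在替换类上,并通过注解的 value 指定要被替换的 Bean 名称;
2. 实现 BeanDefinitionRegistryPostProcessor 接口,扫描带有该注解的类,先移除容器中原有的 Bean 定义,再以相同的名称注册替换类的 Bean 定义,即可实现 Bean 的覆盖。
示例如下: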
// ==================== OverrideBean.java ====================
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a class whose bean definition should replace an existing bean in the Spring container.
 *
 * <p>{@code OverrideBeanConfiguration} scans for this annotation and registers the annotated
 * class under the bean name given by {@link #value()}, replacing the definition previously
 * registered under that name.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface OverrideBean {

    /**
     * Name of the existing bean to replace.
     */
    String value();
}

// ==================== OverrideBeanConfiguration.java ====================
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanDefinitionHolder;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.context.annotation.ClassPathBeanDefinitionScanner;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Replaces existing beans with classes annotated with {@link OverrideBean}.
 *
 * <p>While bean definitions are being registered, this post-processor scans the Spring Boot
 * auto-configuration base packages for {@code @OverrideBean} classes. For each one whose target
 * bean name is already registered, the existing definition is removed and the annotated class's
 * definition is registered under that name instead. Remove-then-register is used rather than
 * relying on bean-definition overriding, which may be disabled in the registry.
 */
@Configuration
public class OverrideBeanConfiguration implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {

    private static final Logger log = LoggerFactory.getLogger(OverrideBeanConfiguration.class);

    private BeanFactory beanFactory;

    /**
     * Scans the auto-configuration packages for {@link OverrideBean} classes and applies the
     * overrides to {@code registry}. Scanning is skipped (with a debug log) when the
     * auto-configuration packages cannot be determined.
     *
     * @param registry the registry whose bean definitions are overridden
     */
    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        log.debug("searching for classes annotated with @OverrideBean");

        List<String> packages;
        try {
            packages = AutoConfigurationPackages.get(this.beanFactory);
        } catch (IllegalStateException ex) {
            // Guard only the package lookup: failures raised while scanning/registering
            // (e.g. conflicting bean definitions) must propagate rather than be reported
            // as a missing auto-configuration package.
            log.debug("could not determine auto-configuration package, automatic OverrideBean scanning disabled.", ex);
            return;
        }
        if (log.isDebugEnabled()) {
            for (String p : packages) {
                log.debug("Using auto-configuration base package: {}", p);
            }
        }

        // Custom scanner that only picks up @OverrideBean classes on the classpath.
        ClassPathOverrideBeanAnnotationScanner scanner = new ClassPathOverrideBeanAnnotationScanner(registry);
        scanner.doScan(StringUtils.toStringArray(packages));
    }

    /**
     * No bean-factory post-processing is needed; all work happens in
     * {@link #postProcessBeanDefinitionRegistry(BeanDefinitionRegistry)}.
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) throws BeansException {
        // intentionally empty
    }

    /**
     * Stores the bean factory used to look up the auto-configuration packages.
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        this.beanFactory = beanFactory;
    }

    /**
     * Classpath scanner that registers only {@link OverrideBean}-annotated classes and then moves
     * each one onto the bean name given by its annotation.
     */
    private static class ClassPathOverrideBeanAnnotationScanner extends ClassPathBeanDefinitionScanner {

        ClassPathOverrideBeanAnnotationScanner(BeanDefinitionRegistry registry) {
            // false: do not register the default @Component include filters
            super(registry, false);
            // Only scan classes annotated with @OverrideBean
            addIncludeFilter(new AnnotationTypeFilter(OverrideBean.class));
        }

        /**
         * Scans {@code basePackages}; for every found class whose {@link OverrideBean#value()}
         * names an already-registered bean, replaces that bean's definition with the found
         * class's definition.
         *
         * @param basePackages packages to scan
         * @return the holders produced by the underlying scan (with their scanner-generated names)
         */
        @Override
        public Set<BeanDefinitionHolder> doScan(String... basePackages) {
            BeanDefinitionRegistry registry = Objects.requireNonNull(getRegistry(), "registry");
            List<String> overrideClassNames = new ArrayList<>();

            // super.doScan registers every found class under a generated bean name.
            Set<BeanDefinitionHolder> beanDefinitions = super.doScan(basePackages);
            for (BeanDefinitionHolder holder : beanDefinitions) {
                BeanDefinition definition = holder.getBeanDefinition();
                String targetName = resolveTargetBeanName(definition);
                if (targetName.isEmpty()) {
                    continue;
                }
                if (!registry.containsBeanDefinition(targetName)) {
                    log.warn("@OverrideBean target bean '{}' not found for class {}; keeping it as bean '{}'",
                            targetName, definition.getBeanClassName(), holder.getBeanName());
                    continue;
                }
                // Drop the generated-name registration so the overriding class is registered only
                // once; otherwise it would exist as two beans and by-type injection could become ambiguous.
                if (!targetName.equals(holder.getBeanName())) {
                    registry.removeBeanDefinition(holder.getBeanName());
                }
                // Replace the target bean with the @OverrideBean class under the same name.
                registry.removeBeanDefinition(targetName);
                registry.registerBeanDefinition(targetName, definition);
                overrideClassNames.add(definition.getBeanClassName());
            }
            log.info("found override beans: {}", overrideClassNames);
            return beanDefinitions;
        }

        /**
         * Reads {@link OverrideBean#value()} from the scanned class's annotation metadata without
         * loading the class.
         *
         * @return the target bean name, or an empty string when the definition carries no
         *         annotation metadata or the value is missing
         */
        private static String resolveTargetBeanName(BeanDefinition definition) {
            if (!(definition instanceof AnnotatedBeanDefinition)) {
                return "";
            }
            Map<String, Object> attributes = ((AnnotatedBeanDefinition) definition).getMetadata()
                    .getAnnotationAttributes(OverrideBean.class.getName());
            return attributes == null ? "" : Objects.toString(attributes.get("value"), "");
        }
    }
}
版权声明
本文仅代表作者观点,不代表本站立场。
本文系作者授权发表,未经许可,不得转载。