用 System.Reflection.Emit 自动生成调用存储过程的实现

类别:Asp 点击:0 评论:0 推荐:

/****************************************************************\
 *
 * 用 System.Reflection.Emit 来自动生成调用储存过程的实现!
 *
 * By http://lostinet.com
 *
 * Copyrights : Not-Reversed
 *
\****************************************************************/

//使用的例子
namespace Lostinet.Sample
{
      using System;
      using System.Data;
      using System.Data.SqlClient;
      using System.Windows.Forms;

      //定义一个接口,用于定义存储过程

      interface INorthwindStoredProcedures
      {
            //定义存储过程对应的方法

            DataSet CustOrderHist(string CustomerID);

            //如果储存过程名字和方法名字不同,应该用SqlAccessAttribute来进行说明
            [SqlAccess("Employee Sales By Country")]
            DataTable EmployeeSalesByCountry(DateTime Beginning_Date,DateTime Ending_Date);

            //...more...

            //MORE Ideas..

            //直接执行SQL语句?
            //[SqlAccess(SqlAccessType.SqlQuery,"SELECT * FROM Employees WHERE EmployeeID=@EmpId")]
            //DataTable SelectEmployee(int EmpId);
      }

      class ConsoleApplication
      {
            [STAThread]
            static void Main(string[] args)
            {
                  using(SqlConnection conn=new SqlConnection("server=(local);trusted_connection=true;database=northwind"))
                  {
                        //一句话就把实现创建了!
                        //需要传如 SqlConnection 和 SqlTransaction
                        //SqlTransaction可以为null

                        //这个好就好在,只要能得到SqlConnection/SqlTransaction就能用这个方法了,所以兼容 Lostinet.Data.SqlScope
                        INorthwindStoredProcedures nsp=(INorthwindStoredProcedures)
                              StoredProcedure.CreateStoredProcedureInterface(typeof(INorthwindStoredProcedures),conn,null);

                        //调用储存过程并且显示

                        ShowData("CustOrderHist ALFKI",nsp.CustOrderHist("ALFKI"));

                        ShowData("Employee Sales By Country",nsp.EmployeeSalesByCountry(new DateTime(1998,1,1),new DateTime(1999,1,1)));

                  }
            }

            static void ShowData(string title,object data)
            {
                  Form f=new Form();
                  f.Width=600;
                  f.Height=480;
                  f.Text=title;

                  DataGrid grid=new DataGrid();                 
                  grid.Dock=DockStyle.Fill;
                  grid.DataSource=data;

                  f.Controls.Add(grid);
                  f.ShowDialog();
            }

      }
}

#region //实现方法(不完整)
namespace Lostinet.Sample
{
      using System;
      using System.Collections;
      using System.Reflection;
      using System.Reflection.Emit;
      using System.Data;
      using System.Data.SqlClient;

      //这个类作为实现的基类,
      //目的是提供储存 SqlConnection/SqlTransaction 和公用的一些方法
      //这个类必须为public,否则无法继承
      //但开发者不会显式访问这个类
      public class SPInterfaceBase : IDisposable
      {
            public SPInterfaceBase()
            {
            }

            public void Dispose()
            {
            }

            //CreateStoredProcedureInterface会把相关的值SqlConnection/SqlTransaction存到这里
            public SqlConnection connection;
            public SqlTransaction transaction;

            //创建一个SqlCommand
            public SqlCommand CreateCommand(string spname)
            {
                  SqlCommand cmd=new SqlCommand(spname,connection,transaction);
                  cmd.CommandType=CommandType.StoredProcedure;
                  //TODO:
                  //cmd.Parameters.Add("@ReturnValue",...
                  return cmd;
            }

            //由 Type 推算出 SqlDbType , 未完成
            SqlDbType GetSqlDbType(Type type)
            {
                  //TODO:switch(type)...

                  return SqlDbType.NVarChar;
            }

            //定义参数
            public void DefineParameter(SqlCommand cmd,string name,Type type,ParameterDirection direction)
            {
                  SqlParameter param=new SqlParameter("@"+name,GetSqlDbType(type));
                  param.Direction=direction;
                  cmd.Parameters.Add(param);
            }

            //在SqlCommand执行前设置参数值
            public void SetParameter(SqlCommand cmd,string name,object value)
            {
                  cmd.Parameters["@"+name].Value=(value==null?DBNull.Value:value);
            }
            //在SqlCommand执行后取得参数值
            public object GetParameter(SqlCommand cmd,string name)
            {
                  return cmd.Parameters["@"+name].Value;
            }

            //根据不同的返回值执行不同的操作

            public SqlDataReader ExecuteDataReader(SqlCommand cmd)
            {
                  return cmd.ExecuteReader();
            }
            public object ExecuteScalar(SqlCommand cmd)
            {
                  return cmd.ExecuteScalar();
            }
            public void ExecuteNonQuery(SqlCommand cmd)
            {
                  cmd.ExecuteNonQuery();
            }
            public DataSet ExecuteDataSet(SqlCommand cmd)
            {
                  DataSet ds=new DataSet();
                  using(SqlDataAdapter sda=new SqlDataAdapter(cmd))
                  {
                        sda.Fill(ds);
                  }
                  return ds;
            }
            public DataTable ExecuteDataTable(SqlCommand cmd)
            {
                  DataTable table=new DataTable();
                  using(SqlDataAdapter sda=new SqlDataAdapter(cmd))
                  {
                        sda.Fill(table);
                  }
                  return table;
            }
            public DataRow ExecuteDataRow(SqlCommand cmd)
            {
                  DataTable table=ExecuteDataTable(cmd);
                  if(table.Rows.Count==0)
                        return null;
                  return table.Rows[0];
            }
      }


      public class StoredProcedure
      {
            static public object CreateStoredProcedureInterface(Type interfaceType,SqlConnection connection,SqlTransaction transaction)
            {
                  //检查参数
                  if(interfaceType==null)throw(new ArgumentNullException("interfaceType"));
                  if(!interfaceType.IsInterface)
                        throw(new ArgumentException("argument is not interface","interfaceType"));
                  if(connection==null)throw(new ArgumentNullException("connection"));
                  if(transaction!=null)
                  {
                        if(transaction.Connection!=connection)
                              throw(new ArgumentException("transaction.Connection!=connection","transaction"));
                  }

                  //创建StoredProcedure

                  StoredProcedure spemit=new StoredProcedure();
                  spemit.interfaceType=interfaceType;
                  spemit.connection=connection;
                  spemit.transaction=transaction;

                  //创建
                  return spemit.CreateInstance();
            }

            //用于储存已创建的类型
            static Hashtable EmittedTypes=new Hashtable();
           
            Type interfaceType;
            SqlConnection connection;
            SqlTransaction transaction;

            private StoredProcedure()
            {
            }

            object CreateInstance()
            {
                  lock(interfaceType)
                  {
                        //如果没有创建具体的实现,则创建它

                        if(emittedType==null)
                        {
                              emittedType=(Type)EmittedTypes[interfaceType];

                              if(emittedType==null)
                              {
                                    CreateType();

                                    //储存已创建类型
                                    EmittedTypes[interfaceType]=emittedType;
                              }
                        }
                  }

                  //创建具体的实例
                  SPInterfaceBase spi=(SPInterfaceBase)Activator.CreateInstance(emittedType);

                  //设置SqlConnection/SqlTransaction
                  spi.connection=connection;
                  spi.transaction=transaction;

                  return spi;
            }

            Type emittedType;

            TypeBuilder typeBuilder;

            //创建类型
            void CreateType()
            {
                  //创建 Assembly
                  //AssemblyBuilderAccess.Run-表示只用于运行,不在磁盘上保存
                  AssemblyName an=new AssemblyName();
                  an.Name="Assembly."+interfaceType.FullName+".Implementation";
                  AssemblyBuilder asmBuilder=AppDomain.CurrentDomain.DefineDynamicAssembly(an,AssemblyBuilderAccess.Run);

                  //创建Module
                  ModuleBuilder mdlBuilder=asmBuilder.DefineDynamicModule("Module."+interfaceType.FullName+".Implementation");

                  //创建Type,该类型继承 SPInterfaceBase
                  typeBuilder=mdlBuilder.DefineType(interfaceType.FullName+".Implementation",TypeAttributes.Class,typeof(SPInterfaceBase));

                  //实现所有的接口方法
                  EmitInterface(interfaceType);

                  //如果interfaceType是基于其他接口的
                  foreach(Type subinterface in interfaceType.GetInterfaces())
                  {
                        //IDisposable不需要实现,由SPInterfaceBase实现了
                        if(subinterface==typeof(IDisposable))
                              continue;

                        EmitInterface(subinterface);
                  }


                  emittedType=typeBuilder.CreateType();
            }

            void EmitInterface(Type type)
            {
                  //实现接口
                  typeBuilder.AddInterfaceImplementation(type);

                  //列出接口的成员
                  foreach(MemberInfo member in type.GetMembers(BindingFlags.Instance|BindingFlags.Public))
                  {
                        //约定-成员必须是方法,不能有属性啊,事件之类的
                        if(member.MemberType!=MemberTypes.Method)
                              throw(new Exception("Could Not Emit "+member.MemberType+" Automatically!"));

                        //取得接口中定义的方法
                        MethodInfo method=(MethodInfo)member;

                        //计算新方法的属性,在原来方法的属性上复制过来,并且不是Public/Abstract,加上Private
                        MethodAttributes methodattrs=method.Attributes;
                        methodattrs&=~(MethodAttributes.Public|MethodAttributes.Abstract);
                        methodattrs|=MethodAttributes.Private;

                       
                        ParameterInfo[] paramInfos=method.GetParameters();
                        int paramlength=paramInfos.Length;

                        //取得参数的类型数组
                        Type[] paramTypes=new Type[paramlength];
                        for(int i=0;i<paramlength;i++)
                        {
                              paramTypes[i]=paramInfos[i].ParameterType;
                        }

                        //在typeBuilder上建立新方法,参数类型与返回类型都与接口上的方法一致
                        MethodBuilder mthBuilder=typeBuilder.DefineMethod(method.Name,methodattrs,method.CallingConvention,method.ReturnType,paramTypes);

                        //复制新方法上的参数的名字和属性
                        for(int i=0;i<paramlength;i++)
                        {
                              ParameterInfo pi=paramInfos[i];
                              //对于Instance,参数position由1开始
                              mthBuilder.DefineParameter(i+1,pi.Attributes,pi.Name);
                        }

                        //指定新方法是实现接口的方法的。
                        typeBuilder.DefineMethodOverride(mthBuilder,method);

                        //在类型上定义一个字段,这个字段用于储存被方法使用的SqlCommand
                        FieldBuilder field_cmd=typeBuilder.DefineField("_cmd_"+method.Name,typeof(SqlCommand),FieldAttributes.Private);

                        //ILGenerator 是用于生成实现代码的对象
                        ILGenerator ilg=mthBuilder.GetILGenerator();

                        //定义临时变量
                        LocalBuilder local_res=ilg.DeclareLocal(typeof(object));

                        //定义一个用于跳转的Label
                        Label label_cmd_ready=ilg.DefineLabel();

                        //this._cmd_MethodName
                        ilg.Emit(OpCodes.Ldarg_0);      //this
                        ilg.Emit(OpCodes.Ldfld,field_cmd);//._cmd_MethodName

                        //if(this._cmd_MethodName!=null) 跳到 label_cmd_ready
                        ilg.Emit(OpCodes.Brtrue,label_cmd_ready);

                        //如果this._cmd_MethodName为null,则运行下面代码来创建SqlCommand

 

                        //this._cmd_MethodName=this.CreateCommand("MethodName");
                        ilg.Emit(OpCodes.Ldarg_0);

                        //this.CreateCommand
                        ilg.Emit(OpCodes.Ldarg_0);//参数0
                        ilg.Emit(OpCodes.Ldstr,SqlAccessAttribute.GetSPName(method));//参数1
                        //调用
                        ilg.Emit(OpCodes.Callvirt,typeof(SPInterfaceBase).GetMethod("CreateCommand",BindingFlags.Instance|BindingFlags.Public));

                        ilg.Emit(OpCodes.Stfld,field_cmd);// ._cmd_MethodName=

                        //this.DefineParameter(...)
                        if(paramlength!=0)
                        {
                              //取得DefineParameter的引用
                              MethodInfo method_DefineParameter=typeof(SPInterfaceBase).GetMethod("DefineParameter",BindingFlags.Instance|BindingFlags.Public);
                             
                              for(int i=0;i<paramlength;i++)
                              {
                                    //取得各参数
                                    ParameterInfo pi=paramInfos[i];

                                    //this.DefineParameter(this._cmd_MethodName,"ParameterName",typeof(ParameterType),ParameterDirection.Xxx);

                                    //参数0 - this
                                    ilg.Emit(OpCodes.Ldarg_0);

                                    //参数1 - this._cmd_MethodName
                                    ilg.Emit(OpCodes.Ldarg_0);
                                    ilg.Emit(OpCodes.Ldfld,field_cmd);

                                    //参数2 - "ParameterName"
                                    ilg.Emit(OpCodes.Ldstr,pi.Name);

                                    //参数3 - typeof(ParameterType)
                                    ilg.Emit(OpCodes.Ldtoken,pi.ParameterType);

                                    //参数4 - ParameterDirection.Xxx
                                    if(pi.ParameterType.IsByRef)
                                    {
                                          ilg.Emit(OpCodes.Ldc_I4,(int)ParameterDirection.InputOutput);
                                    }
                                    else if(pi.IsOut)
                                    {
                                          ilg.Emit(OpCodes.Ldc_I4,(int)ParameterDirection.Output);
                                    }
                                    else
                                    {
                                          ilg.Emit(OpCodes.Ldc_I4,(int)ParameterDirection.Input);
                                    }

                                    //调用DefineParameter
                                    ilg.Emit(OpCodes.Callvirt,method_DefineParameter);
                              }
                        }
                        //到这里 _cmd_CommandName 已经 OK 了。

                        //设置label_cmd_ready就指这里
                        ilg.MarkLabel(label_cmd_ready);
                       
                        //cmd!=null now.

                        if(paramlength!=0)
                        {
                              //现在要把方法的参数的值设置到SqlParameter上

                              MethodInfo method_SetParameter=typeof(SPInterfaceBase).GetMethod("SetParameter",BindingFlags.Instance|BindingFlags.Public);
                             
                              for(int i=0;i<paramlength;i++)
                              {
                                    ParameterInfo pi=paramInfos[i];

                                    //如果参数是 out 的,则不需要设置
                                    if(!pi.ParameterType.IsByRef&&pi.IsOut)
                                          continue;

                                    //this.SetParameter(this._cmd_MethodName,"ParameterName",ParameterName);

                                    ilg.Emit(OpCodes.Ldarg_0);
                                   
                                    ilg.Emit(OpCodes.Ldarg_0);
                                    ilg.Emit(OpCodes.Ldfld,field_cmd);

                                    ilg.Emit(OpCodes.Ldstr,pi.Name);

                                    //取得参数值,如果参数为ValueType,则Box到Object
                                    ilg.Emit(OpCodes.Ldarg,i+1);
                                    if(pi.ParameterType.IsValueType)
                                          ilg.Emit(OpCodes.Box,pi.ParameterType);

                                    ilg.Emit(OpCodes.Callvirt,method_SetParameter);
                              }
                        }

                        //现在要执行储存过程(执行SqlCommand)了

                        //这里根据返回值类型判断怎样执行SqlCommand

                        Type returnType=method.ReturnType;

                        //如果是 void 的,则不需要返回值
                        bool nores=returnType==typeof(void);

                        MethodInfo method_Execute=null;

                        if(nores)
                        {
                              //不需要返回值
                              method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteNonQuery",BindingFlags.Instance|BindingFlags.Public);
                        }
                        else if(returnType==typeof(object))
                        {
                              //返回object
                              method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteScalar",BindingFlags.Instance|BindingFlags.Public);
                        }
                        else if(returnType==typeof(DataSet))
                        {
                              //返回DataSet
                              method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteDataSet",BindingFlags.Instance|BindingFlags.Public);
                        }
                        else if(returnType==typeof(DataTable))
                        {
                              //返回DataTable
                              method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteDataTable",BindingFlags.Instance|BindingFlags.Public);
                        }
                        else if(returnType==typeof(DataRow))
                        {
                              //返回DataRow
                              method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteDataRow",BindingFlags.Instance|BindingFlags.Public);
                        }
                        else
                        {
                              //返回其他类型
                              foreach(Type retInterface in returnType.GetInterfaces())
                              {
                                    //如果是返回IDataReader
                                    if(retInterface==typeof(IDataReader))
                                    {
                                          //只支持SqlDataReader
                                          if(!returnType.IsAssignableFrom(typeof(SqlDataReader)))
                                                throw(new Exception("SqlDataReader could not convert to "+returnType.FullName));

                                          method_Execute=typeof(SPInterfaceBase).GetMethod("ExecuteDataReader",BindingFlags.Instance|BindingFlags.Public);
                                          break;
                                    }
                              }
                        }

                        //如果找不到适合的策略,
                        if(method_Execute==null)
                        {
                              //TODO:当然,这里应该有返回Int32,String,...的,不过懒得再写了。

                              //抛出异常,提示不支持该返回类型,要作者改改:)
                              throw(new NotSupportedException("NotSupport ReturnType:"+returnType.FullName));
                        }

                        //this.ExecuteXXX(this._cmd_MethodName)
                        ilg.Emit(OpCodes.Ldarg_0);
                        ilg.Emit(OpCodes.Ldarg_0);
                        ilg.Emit(OpCodes.Ldfld,field_cmd);
                        ilg.Emit(OpCodes.Callvirt,method_Execute);

                        //如果有返回值的,则是
                        //local_res=this.ExecuteXXX(this._cmd_MethodName)
                        if(!nores)
                        {
                              if(returnType.IsValueType)
                                    ilg.Emit(OpCodes.Box,returnType);
                              ilg.Emit(OpCodes.Stloc,local_res);
                        }

                        if(paramlength!=0)
                        {
                              //这里处理ref/out的参数
                              MethodInfo method_GetParameter=typeof(SPInterfaceBase).GetMethod("GetParameter",BindingFlags.Instance|BindingFlags.Public);
                             
                              for(int i=0;i<paramlength;i++)
                              {
                                    ParameterInfo pi=paramInfos[i];

                                    //如果不是ref/out则跳过
                                    if(!pi.ParameterType.IsByRef&&!pi.IsOut)
                                          continue;

                                    //ParameterName=this.GetParameter(this._cmd_Methodname,"ParameterName")
                                    ilg.Emit(OpCodes.Ldarg_0);
                                   
                                    ilg.Emit(OpCodes.Ldarg_0);
                                    ilg.Emit(OpCodes.Ldfld,field_cmd);

                                    ilg.Emit(OpCodes.Ldstr,pi.Name);

                                    ilg.Emit(OpCodes.Callvirt,method_GetParameter);

                                    //如果类型是值类型,则需要 Unbox
                                    if(pi.ParameterType.IsValueType)
                                          ilg.Emit(OpCodes.Unbox,pi.ParameterType);

                                    ilg.Emit(OpCodes.Starg,i+1);                                   
                              }
                        }

                        //如果是 void , 则直接 return;
                        //否者是 return local_res , 如果返回值类型是ValueType,则需要Unbox
                        if(!nores)
                        {
                              ilg.Emit(OpCodes.Ldloc,local_res);
                              if(returnType.IsValueType)
                                    ilg.Emit(OpCodes.Unbox,returnType);
                        }
                        ilg.Emit(OpCodes.Ret);

//                        //throw(new NotImplementedException());
//                        ilg.Emit(OpCodes.Newobj,typeof(NotImplementedException).GetConstructor(new Type[0]));
//                        ilg.Emit(OpCodes.Throw);

                  }
            }
      }

      public enum SqlAccessType
      {
            StoredProcedure
            //TODO:
            //,SqlQuery
      }

      [AttributeUsage(AttributeTargets.Method)]
      public class SqlAccessAttribute : Attribute
      {
            string _sp;

            public SqlAccessAttribute(string spname)
            {
                  _sp=spname;
            }

            public string StoreProcedure
            {
                  get
                  {
                        return _sp;
                  }
            }

            static public string GetSPName(MethodInfo method)
            {
                  if(method==null)throw(new ArgumentNullException("method"));

                  object[] attrs=method.GetCustomAttributes(typeof(SqlAccessAttribute),false);
                  if(attrs==null||attrs.Length==0)
                        return method.Name;

                  return ((SqlAccessAttribute)attrs[0]).StoreProcedure;
            }

            //TODO:
//            public SqlAccessAttribute(SqlAccessType type,string text)
//            {
//
//            }
      }
     
}
#endregion

本文地址:http://com.8s8s.com/it/it7935.htm